diff --git a/.gitignore b/.gitignore index 10a4262aa7e129c48d79fbe7d978720b28f4bcea..369fa1cb919c82caec326d1429c8a2eba3b928d6 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ +python/paddle/fluid/tests/unittests/reader_reset_test.recordio paddle/operators/check_t.save paddle/operators/check_tensor.ls paddle/operators/tensor.save diff --git a/AUTHORS.md b/AUTHORS.md index 4060f75613ac4dadf353ff53a73fd0647a8052be..54a1097b50f7a09062f8987e62db6b5f5e89e0b7 100644 --- a/AUTHORS.md +++ b/AUTHORS.md @@ -42,6 +42,7 @@ | QiJune | Jun Qi | | qingqing01 | Qing-Qing Dang | | reyoung | Yang Yu | +| Sand3r- | Michal Gallus | | Superjom | Chun-Wei Yan | | tensor-tang | Jian Tang | | tianbingsz | Tian-Bing Xu | diff --git a/cmake/inference_lib.cmake b/cmake/inference_lib.cmake index 729bdcb3dc5324df0a5272402ef203012be0072a..7355b67ab1020f58760f23b1a20ca189591db35e 100644 --- a/cmake/inference_lib.cmake +++ b/cmake/inference_lib.cmake @@ -166,8 +166,8 @@ copy(framework_lib DEPS ${framework_lib_deps} set(module "memory") copy(memory_lib - SRCS ${src_dir}/${module}/*.h ${src_dir}/${module}/detail/*.h - DSTS ${dst_dir}/${module} ${dst_dir}/${module}/detail + SRCS ${src_dir}/${module}/*.h ${src_dir}/${module}/detail/*.h ${src_dir}/${module}/allocation/*.h + DSTS ${dst_dir}/${module} ${dst_dir}/${module}/detail ${dst_dir}/${module}/allocation ) set(inference_deps paddle_fluid_shared paddle_fluid) diff --git a/paddle/fluid/framework/operator.h b/paddle/fluid/framework/operator.h index 40b0130b265471a1288d966c4cbcd4f0e1bdb9f1..6918e030bf859bc8a55baed9d944e16217b0efb6 100644 --- a/paddle/fluid/framework/operator.h +++ b/paddle/fluid/framework/operator.h @@ -100,6 +100,7 @@ class OperatorBase { const std::string& Type() const { return type_; } + bool HasAttr(const std::string& name) const { return attrs_.count(name); } template inline const T& Attr(const std::string& name) const { PADDLE_ENFORCE(attrs_.count(name) != 0, "%s should be in AttributeMap", diff --git a/paddle/fluid/inference/analysis/CMakeLists.txt b/paddle/fluid/inference/analysis/CMakeLists.txt index eb89fc5e1124e97b082d6299e3efc44591a8b01b..0c73778b201d77a6e8a35a38d17f2a86d5faaca9 100644 --- a/paddle/fluid/inference/analysis/CMakeLists.txt +++ b/paddle/fluid/inference/analysis/CMakeLists.txt @@ -7,16 +7,17 @@ set(analysis_deps # analysis_deps can be extended accross the project add_subdirectory(ir_passes) add_subdirectory(passes) -cc_library(ir_pass_manager SRCS ir_pass_manager.cc DEPS graph pass ${INFER_IR_PASSES}) +cc_library(analysis_helper SRCS helper.cc DEPS framework_proto proto_desc graph paddle_fluid_api) + +cc_library(ir_pass_manager SRCS ir_pass_manager.cc DEPS graph pass ${INFER_IR_PASSES} analysis_helper) cc_library(argument SRCS argument.cc DEPS scope proto_desc) cc_library(analysis_pass SRCS analysis_pass.cc DEPS proto_desc) cc_library(analysis SRCS analyzer.cc - helper.cc analysis_pass - DEPS ${analysis_deps} + DEPS ${analysis_deps} analysis_helper ) cc_test(test_dot SRCS dot_tester.cc DEPS analysis) diff --git a/paddle/fluid/inference/analysis/analyzer_tester.cc b/paddle/fluid/inference/analysis/analyzer_tester.cc index 48fc5dda2a5bfa24d679d4bf655e580dafc614b3..84a0c3374c66f85313828332099cb372e14c7c83 100644 --- a/paddle/fluid/inference/analysis/analyzer_tester.cc +++ b/paddle/fluid/inference/analysis/analyzer_tester.cc @@ -30,6 +30,7 @@ TEST(Analyzer, analysis_without_tensorrt) { Argument argument; argument.SetModelDir(FLAGS_inference_model_dir); argument.SetIrAnalysisPasses({"infer_clean_graph_pass"}); + argument.SetUseGPU(false); Analyzer analyser; analyser.Run(&argument); @@ -41,6 +42,7 @@ TEST(Analyzer, analysis_with_tensorrt) { argument.SetTensorRtWorkspaceSize(1 << 20); argument.SetModelDir(FLAGS_inference_model_dir); argument.SetIrAnalysisPasses({"infer_clean_graph_pass"}); + argument.SetUseGPU(false); Analyzer analyser; analyser.Run(&argument); diff --git a/paddle/fluid/inference/analysis/argument.h b/paddle/fluid/inference/analysis/argument.h index d7a2f3d1e3a3251263c8670aef5db538fa2c48ea..21203e2d9f4e4cd22ea49ea7b6808aff07e70eff 100644 --- a/paddle/fluid/inference/analysis/argument.h +++ b/paddle/fluid/inference/analysis/argument.h @@ -116,6 +116,7 @@ struct Argument { std::vector); DECL_ARGUMENT_FIELD(use_gpu, UseGPU, bool); + DECL_ARGUMENT_FIELD(gpu_device_id, GPUDeviceId, int); DECL_ARGUMENT_FIELD(use_tensorrt, UseTensorRT, bool); DECL_ARGUMENT_FIELD(tensorrt_node_teller, TensorRtNodeTeller, std::function); diff --git a/paddle/fluid/inference/analysis/ir_passes/CMakeLists.txt b/paddle/fluid/inference/analysis/ir_passes/CMakeLists.txt index c71cff889ed7cdb95f79b9bc89a9ca5ab370271c..822c7799bb3ae6d79da6cf2a7b3c8c9b20353ed7 100644 --- a/paddle/fluid/inference/analysis/ir_passes/CMakeLists.txt +++ b/paddle/fluid/inference/analysis/ir_passes/CMakeLists.txt @@ -4,4 +4,6 @@ set(analysis_deps ${analysis_deps} subgraph_detector tensorrt_subgraph_pass CACHE INTERNAL "") +set(pass_file ${PADDLE_BINARY_DIR}/paddle/fluid/inference/api/paddle_inference_pass.h) +file(APPEND ${pass_file} "USE_PASS(tensorrt_subgraph_pass);\n") set(INFER_IR_PASSES ${INFER_IR_PASSES} tensorrt_subgraph_pass CACHE INTERNAL "") diff --git a/paddle/fluid/inference/analysis/passes/ir_analysis_compose_pass.cc b/paddle/fluid/inference/analysis/passes/ir_analysis_compose_pass.cc index 267737e95cb51f9a2d72b70b171be779d6f52c8d..108cb6f74b1208395a4faabdf6184152c300d244 100644 --- a/paddle/fluid/inference/analysis/passes/ir_analysis_compose_pass.cc +++ b/paddle/fluid/inference/analysis/passes/ir_analysis_compose_pass.cc @@ -46,7 +46,7 @@ void IrAnalysisComposePass::InitTensorRTAttrs(Argument *argument) { {"mul", "conv2d", "pool2d", "relu", "softmax", "sigmoid", "depthwise_conv2d", "batch_norm", "concat", "tanh", "pad", "elementwise_add", "elementwise_mul", "dropout", "split", "prelu", - "conv2d_transpose"}); + "conv2d_transpose", "leaky_relu"}); if (!node->IsOp()) return false; if (teller_set.count(node->Op()->Type())) { diff --git a/paddle/fluid/inference/analysis/passes/ir_graph_build_pass.cc b/paddle/fluid/inference/analysis/passes/ir_graph_build_pass.cc index a30fef08b5726c965637e2fb489bdb2036bd2a8d..d5e0d90de1da8e54e2411c266f7a8c09c33b0336 100644 --- a/paddle/fluid/inference/analysis/passes/ir_graph_build_pass.cc +++ b/paddle/fluid/inference/analysis/passes/ir_graph_build_pass.cc @@ -30,15 +30,28 @@ void IrGraphBuildPass::RunImpl(Argument *argument) { if (!argument->scope_valid()) { argument->SetScope(new framework::Scope); } + PADDLE_ENFORCE(argument->use_gpu_valid()); + + // The load program should run on the same device with the inference program, + // so that the parameters will on the same device, or they will keep copying + // between difference devices. + platform::Place place; + if (argument->use_gpu()) { + PADDLE_ENFORCE(argument->gpu_device_id_valid()); + place = platform::CUDAPlace(argument->gpu_device_id()); + } else { + place = platform::CPUPlace(); + } if (argument->model_dir_valid()) { - auto program = LoadModel(argument->model_dir(), argument->scope_ptr()); + auto program = + LoadModel(argument->model_dir(), argument->scope_ptr(), place); argument->SetMainProgram(program.release()); } else if (argument->model_program_path_valid() && argument->model_params_path_valid()) { auto program = LoadModel(argument->model_program_path(), argument->model_params_path(), - argument->scope_ptr()); + argument->scope_ptr(), place); argument->SetMainProgram(program.release()); } else { PADDLE_THROW( @@ -52,16 +65,15 @@ void IrGraphBuildPass::RunImpl(Argument *argument) { } std::unique_ptr IrGraphBuildPass::LoadModel( - const std::string &path, framework::Scope *scope) { - platform::CPUPlace place; + const std::string &path, framework::Scope *scope, + const platform::Place &place) { framework::Executor exe(place); return Load(&exe, scope, path); } std::unique_ptr IrGraphBuildPass::LoadModel( const std::string &program_path, const std::string ¶ms_path, - framework::Scope *scope) { - platform::CPUPlace place; + framework::Scope *scope, const platform::Place &place) { framework::Executor exe(place); return Load(&exe, scope, program_path, params_path); } diff --git a/paddle/fluid/inference/analysis/passes/ir_graph_build_pass.h b/paddle/fluid/inference/analysis/passes/ir_graph_build_pass.h index 3291e4f6ad3ca3079e672350805cab1f1e7b2413..271e64fce579bc9001b1dd632576571cec949752 100644 --- a/paddle/fluid/inference/analysis/passes/ir_graph_build_pass.h +++ b/paddle/fluid/inference/analysis/passes/ir_graph_build_pass.h @@ -17,6 +17,7 @@ #include #include "paddle/fluid/framework/scope.h" #include "paddle/fluid/inference/analysis/analysis_pass.h" +#include "paddle/fluid/platform/place.h" namespace paddle { namespace inference { @@ -32,11 +33,12 @@ class IrGraphBuildPass : public AnalysisPass { std::string repr() const override; private: - std::unique_ptr LoadModel(const std::string &path, - framework::Scope *scope); + std::unique_ptr LoadModel( + const std::string &path, framework::Scope *scope, + const platform::Place &place); std::unique_ptr LoadModel( const std::string &program_path, const std::string ¶ms_path, - framework::Scope *scope); + framework::Scope *scope, const platform::Place &place); std::string model_binary_str_; }; diff --git a/paddle/fluid/inference/api/CMakeLists.txt b/paddle/fluid/inference/api/CMakeLists.txt index 82f74a269a5915dfa1d97a28f5ae15a12ea0b154..e9969b84f33483b048951f704de1e13e51cbeaea 100644 --- a/paddle/fluid/inference/api/CMakeLists.txt +++ b/paddle/fluid/inference/api/CMakeLists.txt @@ -27,11 +27,10 @@ endif() cc_library(reset_tensor_array SRCS details/reset_tensor_array.cc DEPS lod_tensor scope) cc_library(analysis_config SRCS analysis_config.cc DEPS lod_tensor paddle_pass_builder) cc_library(paddle_pass_builder SRCS paddle_pass_builder.cc) -cc_library(paddle_inference_api SRCS api.cc api_impl.cc helper.cc DEPS lod_tensor scope paddle_pass_builder reset_tensor_array analysis_config analysis_config paddle_pass_builder) -cc_library(analysis_predictor SRCS analysis_predictor.cc DEPS paddle_inference_api analysis naive_executor zero_copy_tensor reset_tensor_array analysis_config paddle_pass_builder) -cc_library(zero_copy_tensor SRCS details/zero_copy_tensor.cc DEPS paddle_inference_api) -cc_library(zero_copy_tensor_dummy SRCS details/zero_copy_tensor_dummy.cc DEPS paddle_inference_api) - +cc_library(analysis_predictor SRCS analysis_predictor.cc DEPS paddle_inference_api analysis naive_executor zero_copy_tensor reset_tensor_array analysis_config paddle_pass_builder ir_pass_manager) +cc_library(zero_copy_tensor SRCS details/zero_copy_tensor.cc DEPS scope lod_tensor enforce) +cc_library(zero_copy_tensor_dummy SRCS details/zero_copy_tensor_dummy.cc) +cc_library(paddle_inference_api SRCS api.cc api_impl.cc helper.cc DEPS lod_tensor scope paddle_pass_builder reset_tensor_array analysis_config analysis_config paddle_pass_builder DEPS zero_copy_tensor) cc_test(test_paddle_inference_api SRCS api_tester.cc diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc index d19505877bbc1110fcf5787fffc1436d242a7cdc..cb14d2a2602808bd35106ed2bafcf7975f549597 100644 --- a/paddle/fluid/inference/api/analysis_predictor.cc +++ b/paddle/fluid/inference/api/analysis_predictor.cc @@ -285,6 +285,7 @@ void AnalysisPredictor::OptimizeInferenceProgram() { status_program_optimized_ = true; argument_.SetUseGPU(config_.use_gpu); + argument_.SetGPUDeviceId(config_.device); // Analyze inference_program if (!config_.model_dir.empty()) { argument_.SetModelDir(config_.model_dir); @@ -491,8 +492,7 @@ bool AnalysisPredictor::LoadParameters() { } // Use NaiveExecutor to Load parameters. - platform::CPUPlace place; - framework::NaiveExecutor e(place); + framework::NaiveExecutor e(place_); e.Prepare(scope_.get(), *load_program, 0, false); e.Run(); VLOG(3) << "get " << scope_->LocalVarNames().size() << " vars after load"; @@ -551,4 +551,5 @@ USE_TRT_CONVERTER(pad); USE_TRT_CONVERTER(split); USE_TRT_CONVERTER(prelu); USE_TRT_CONVERTER(conv2d_transpose); +USE_TRT_CONVERTER(leaky_relu); #endif diff --git a/paddle/fluid/inference/api/paddle_pass_builder.h b/paddle/fluid/inference/api/paddle_pass_builder.h index 825bee833bf918067497f56adebbbcaf55f892a2..12e3a6f42e14010feedbbb5d8f8a98f60cea4556 100644 --- a/paddle/fluid/inference/api/paddle_pass_builder.h +++ b/paddle/fluid/inference/api/paddle_pass_builder.h @@ -116,8 +116,12 @@ class CpuPassStrategy : public PassStrategy { class GpuPassStrategy : public PassStrategy { public: GpuPassStrategy() : PassStrategy({}) { + // TODO(NHZlX) Problem with Data synchronization between GPU and CPU + // When running in GPU mode, the parameters are all on GPU. But the + // opearations of "conv_bn_fuse_pass" are on CPU. passes_.assign({ - "infer_clean_graph_pass", "conv_bn_fuse_pass", + "infer_clean_graph_pass", + // "infer_clean_graph_pass", "conv_bn_fuse_pass", }); } diff --git a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt index 8dd6e8453f971b980a3dcb95c10734537e8333dd..27fb41d16ead65a1ec075399bcda135e2238c7ba 100644 --- a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt +++ b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt @@ -2,7 +2,7 @@ nv_library(tensorrt_converter SRCS mul_op.cc conv2d_op.cc fc_op.cc pool2d_op.cc elementwise_op.cc batch_norm_op.cc activation_op.cc softmax_op.cc concat_op.cc dropout_op.cc - pad_op.cc split_op.cc prelu_op.cc + pad_op.cc split_op.cc prelu_op.cc leaky_relu_op.cc DEPS tensorrt_engine tensorrt_plugin operator scope framework_proto op_registry) nv_test(test_op_converter SRCS test_op_converter.cc DEPS @@ -38,3 +38,5 @@ nv_test(test_trt_split_op SRCS test_split_op.cc split_op.cc nv_test(test_trt_prelu_op SRCS test_prelu_op.cc prelu_op.cc DEPS ${FLUID_CORE_MODULES} ${GLOB_OPERATOR_DEPS} tensorrt_engine tensorrt_plugin prelu_op SERIAL) +nv_test(test_trt_leaky_relu_op SRCS test_leaky_relu_op.cc leaky_relu_op.cc + DEPS ${FLUID_CORE_MODULES} ${GLOB_OPERATOR_DEPS} tensorrt_engine activation_op SERIAL) diff --git a/paddle/fluid/inference/tensorrt/convert/leaky_relu_op.cc b/paddle/fluid/inference/tensorrt/convert/leaky_relu_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..3f6ed04c46d70b1ab68b4c01ef0c908a1a8d1a19 --- /dev/null +++ b/paddle/fluid/inference/tensorrt/convert/leaky_relu_op.cc @@ -0,0 +1,95 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/inference/tensorrt/convert/op_converter.h" + +namespace paddle { +namespace inference { +namespace tensorrt { + +// LeakyRelu converter from fluid to tensorRT +class LeakyReluOpConverter : public OpConverter { + public: + void operator()(const framework::proto::OpDesc& op, + const framework::Scope& scope, bool test_mode) override { + VLOG(4) << "convert fluid leaky_relu op to tensorrt layer"; + + framework::OpDesc op_desc(op, nullptr); + // Declare inputs + int input_num = op_desc.Input("X").size(); + PADDLE_ENFORCE(input_num == 1); + auto* input = engine_->GetITensor(op_desc.Input("X")[0]); + // Get output + size_t output_num = op_desc.Output("Out").size(); + PADDLE_ENFORCE(output_num == 1); + // Get attrs + float alpha = boost::get(op_desc.GetAttr("alpha")); + + platform::CPUPlace place; + std::unique_ptr alpha_tensor( + new framework::LoDTensor()); + alpha_tensor->Resize(framework::make_ddim({2})); + float* alpha_data = alpha_tensor->mutable_data(place); + alpha_data[0] = alpha; + alpha_data[1] = 1.f - alpha; + // the leaky relu formula y = (x > 0) ? x : alpha * x is equal to + // y = alpha * x + (x > 0) ? (1 - alpha) * x : 0 + TensorRTEngine::Weight scale{nvinfer1::DataType::kFLOAT, &alpha_data[0], 1}; + TensorRTEngine::Weight shift{nvinfer1::DataType::kFLOAT, nullptr, 0}; + TensorRTEngine::Weight power{nvinfer1::DataType::kFLOAT, nullptr, 0}; + // y_scale = alpha * x + auto* scale_layer = TRT_ENGINE_ADD_LAYER( + engine_, Scale, *input, nvinfer1::ScaleMode::kUNIFORM, shift.get(), + scale.get(), power.get()); + PADDLE_ENFORCE(nullptr != scale_layer); + // y_relu = (x > 0) : x : 0 + auto* relu_layer = TRT_ENGINE_ADD_LAYER(engine_, Activation, *input, + nvinfer1::ActivationType::kRELU); + PADDLE_ENFORCE(nullptr != relu_layer); + // + TensorRTEngine::Weight sub_scale{nvinfer1::DataType::kFLOAT, &alpha_data[1], + 1}; + auto* scale_relu_layer = + TRT_ENGINE_ADD_LAYER(engine_, Scale, *(relu_layer->getOutput(0)), + nvinfer1::ScaleMode::kUNIFORM, shift.get(), + sub_scale.get(), power.get()); + PADDLE_ENFORCE(nullptr != scale_relu_layer); + auto* output_layer = + TRT_ENGINE_ADD_LAYER(engine_, ElementWise, *(scale_layer->getOutput(0)), + *(scale_relu_layer->getOutput(0)), + nvinfer1::ElementWiseOperation::kSUM); + PADDLE_ENFORCE(nullptr != output_layer); + // keep alpha tensor to avoid release it's memory + std::string alpha_name = op_desc.Output("Out")[0] + "_alpha"; + PADDLE_ENFORCE(engine_->weight_map.find(alpha_name) == + engine_->weight_map.end()); + engine_->weight_map[alpha_name] = std::move(alpha_tensor); + + std::string layer_name = "leaky_relu (Output: "; + auto output_name = op_desc.Output("Out")[0]; + output_layer->getOutput(0)->setName(output_name.c_str()); + engine_->SetITensor(output_name, output_layer->getOutput(0)); + layer_name += output_name; + if (test_mode) { + engine_->DeclareOutput(output_name); + } + output_layer->setName((layer_name + ")").c_str()); + } +}; + +} // namespace tensorrt +} // namespace inference +} // namespace paddle + +REGISTER_TRT_OP_CONVERTER(leaky_relu, LeakyReluOpConverter); diff --git a/paddle/fluid/inference/tensorrt/convert/test_leaky_relu_op.cc b/paddle/fluid/inference/tensorrt/convert/test_leaky_relu_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..d00826af075159004d3727a7519e7c319dbddb02 --- /dev/null +++ b/paddle/fluid/inference/tensorrt/convert/test_leaky_relu_op.cc @@ -0,0 +1,48 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include "paddle/fluid/inference/tensorrt/convert/op_converter.h" +#include "paddle/fluid/inference/tensorrt/convert/ut_helper.h" + +namespace paddle { +namespace inference { +namespace tensorrt { + +TEST(leaky_relu_op, test_leaky_relu) { + std::unordered_set parameters; + framework::Scope scope; + TRTConvertValidation validator(10, parameters, scope, 1000); + validator.DeclInputVar("leaky_relu_input", nvinfer1::DimsCHW(3, 2, 2)); + validator.DeclOutputVar("leaky_relu_out", nvinfer1::DimsCHW(3, 2, 2)); + + // Prepare Op description + framework::OpDesc desc; + desc.SetType("leaky_relu"); + desc.SetInput("X", {"leaky_relu_input"}); + desc.SetOutput("Out", {"leaky_relu_out"}); + + desc.SetAttr("alpha", 0.1f); + + validator.SetOp(*desc.Proto()); + + validator.Execute(1); +} + +} // namespace tensorrt +} // namespace inference +} // namespace paddle + +// USE_OP(leaky_relu); +USE_OP(leaky_relu); diff --git a/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt b/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt index 40902694995f487343e3cfc727ad95daafa6b415..a0329325bea19bd9cdd3fcd39724cf05664b505a 100644 --- a/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt +++ b/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt @@ -1,3 +1,3 @@ nv_library(tensorrt_plugin SRCS trt_plugin.cc split_op_plugin.cu elementwise_op_plugin.cu prelu_op_plugin.cu - DEPS enforce device_context) + DEPS enforce tensorrt_engine) diff --git a/paddle/fluid/inference/tests/api/CMakeLists.txt b/paddle/fluid/inference/tests/api/CMakeLists.txt index 16a9b50e6fb174374d23cd021e47e52921871a8a..e8bd13037ed6c2c3c639b76f6f3561921fb6ee37 100644 --- a/paddle/fluid/inference/tests/api/CMakeLists.txt +++ b/paddle/fluid/inference/tests/api/CMakeLists.txt @@ -1,5 +1,9 @@ set(INFERENCE_EXTRA_DEPS paddle_inference_api paddle_fluid_api ir_pass_manager analysis_predictor) +if(WITH_GPU AND TENSORRT_FOUND) + set(INFERENCE_EXTRA_DEPS ${INFERENCE_EXTRA_DEPS} analysis ${analysis_deps} ir_pass_manager analysis_predictor) +endif() + function(download_model install_dir model_name) if (NOT EXISTS ${install_dir}) inference_download_and_uncompress(${install_dir} ${INFERENCE_URL} ${model_name}) @@ -27,14 +31,14 @@ function(inference_analysis_api_test_with_fake_data target install_dir filename endfunction() # RNN1 -if(NOT APPLE) +if(NOT APPLE AND WITH_MKLML) set(RNN1_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/rnn1") download_model_and_data(${RNN1_INSTALL_DIR} "rnn1%2Fmodel.tar.gz" "rnn1%2Fdata.txt.tar.gz") inference_analysis_api_test(test_analyzer_rnn1 ${RNN1_INSTALL_DIR} analyzer_rnn1_tester.cc) else() - # TODO: fix this test on MACOS, the reason is that - # fusion_seqexpand_concat_fc_op is not supported on MACOS - message(WARNING "These tests has been disabled in OSX before being fixed: \n test_analyzer_rnn1") + # TODO: fix this test on MACOS and OPENBLAS, the reason is that + # fusion_seqexpand_concat_fc_op is not supported on MACOS and OPENBLAS + message(WARNING "These tests has been disabled in OSX or WITH_MKL=OFF before being fixed: \n test_analyzer_rnn1") endif() # RNN2 @@ -75,11 +79,11 @@ endif() inference_analysis_api_test(test_analyzer_ocr ${OCR_INSTALL_DIR} analyzer_vis_tester.cc) # resnet50 -inference_analysis_api_test_with_fake_data(test_analyzer_resnet50 +inference_analysis_api_test_with_fake_data(test_analyzer_resnet50 "${INFERENCE_DEMO_INSTALL_DIR}/resnet50" analyzer_resnet50_tester.cc "resnet50_model.tar.gz") # mobilenet with depthwise_conv op -inference_analysis_api_test_with_fake_data(test_analyzer_mobilenet +inference_analysis_api_test_with_fake_data(test_analyzer_mobilenet "${INFERENCE_DEMO_INSTALL_DIR}/mobilenet_depthwise_conv" analyzer_resnet50_tester.cc "mobilenet_model.tar.gz") # anakin @@ -89,15 +93,15 @@ if (WITH_ANAKIN AND WITH_MKL) # only needed in CI set(ANAKIN_RNN1_INSTALL_DIR "${ANAKIN_INSTALL_DIR}/rnn1") inference_download(${ANAKIN_RNN1_INSTALL_DIR} ${INFERENCE_URL} "anakin_test%2Fditu_rnn.anakin2.model.bin") inference_download(${ANAKIN_RNN1_INSTALL_DIR} ${INFERENCE_URL} "anakin_test%2Fditu_rnn_data.txt") - cc_test(test_anakin_rnn1 SRCS anakin_rnn1_tester.cc - ARGS --model=${ANAKIN_RNN1_INSTALL_DIR}/anakin_test%2Fditu_rnn.anakin2.model.bin + cc_test(test_anakin_rnn1 SRCS anakin_rnn1_tester.cc + ARGS --model=${ANAKIN_RNN1_INSTALL_DIR}/anakin_test%2Fditu_rnn.anakin2.model.bin --datapath=${ANAKIN_RNN1_INSTALL_DIR}/anakin_test%2Fditu_rnn_data.txt DEPS inference_anakin_api_shared SERIAL) # anakin mobilenet if(WITH_GPU) set(ANAKIN_MOBILENET_INSTALL_DIR "${ANAKIN_INSTALL_DIR}/mobilenet") inference_download(${ANAKIN_MOBILENET_INSTALL_DIR} ${INFERENCE_URL} "mobilenet_v2.anakin.bin") - cc_test(test_anakin_mobilenet SRCS anakin_mobilenet_tester.cc + cc_test(test_anakin_mobilenet SRCS anakin_mobilenet_tester.cc ARGS --model=${ANAKIN_MOBILENET_INSTALL_DIR}/mobilenet_v2.anakin.bin DEPS inference_anakin_api_shared dynload_cuda SERIAL) endif() @@ -109,6 +113,6 @@ if(WITH_GPU AND TENSORRT_FOUND) inference_download_and_uncompress(${TRT_MODEL_INSTALL_DIR} ${INFERENCE_URL}/tensorrt_test "trt_test_models.tar.gz") endif() inference_analysis_test(test_trt_models SRCS trt_models_tester.cc - EXTRA_DEPS ${INFERENCE_EXTRA_DEPS} analysis ${analysis_deps} ir_pass_manager analysis_predictor + EXTRA_DEPS ${INFERENCE_EXTRA_DEPS} ARGS --infer_model=${TRT_MODEL_INSTALL_DIR}/trt_test_models SERIAL) endif() diff --git a/paddle/fluid/inference/tests/api/tester_helper.h b/paddle/fluid/inference/tests/api/tester_helper.h index e66ae2805766754d9d07877c31889dd421daf9f1..7b686045a59c93a93322f99c2cdf7050ddbf0a6d 100644 --- a/paddle/fluid/inference/tests/api/tester_helper.h +++ b/paddle/fluid/inference/tests/api/tester_helper.h @@ -222,19 +222,36 @@ void TestMultiThreadPrediction( // The inputs of each thread are all the same. std::vector outputs_tid; auto &predictor = predictors[tid]; - LOG(INFO) << "running thread " << tid; - Timer timer; - timer.tic(); - for (int i = 0; i < num_times; i++) { - for (const auto &input : inputs) { - ASSERT_TRUE(predictor->Run(input, &outputs_tid)); + + // warmup run + LOG(INFO) << "Running thread " << tid << ", warm up run..."; + { + Timer warmup_timer; + warmup_timer.tic(); + predictor->Run(inputs[0], outputs, batch_size); + PrintTime(batch_size, 1, num_threads, tid, warmup_timer.toc(), 1); +#if !defined(_WIN32) + if (FLAGS_profile) { + paddle::platform::ResetProfiler(); } +#endif } - auto time = timer.toc(); - total_time += time; - PrintTime(batch_size, num_times, num_threads, tid, time / num_times, - inputs.size()); + LOG(INFO) << "Thread " << tid << " run " << num_times << " times..."; + { + Timer timer; + timer.tic(); + for (int i = 0; i < num_times; i++) { + for (const auto &input : inputs) { + ASSERT_TRUE(predictor->Run(input, &outputs_tid)); + } + } + + auto time = timer.toc(); + total_time += time; + PrintTime(batch_size, num_times, num_threads, tid, time / num_times, + inputs.size()); + } }); } for (int i = 0; i < num_threads; ++i) { diff --git a/paddle/fluid/inference/tests/api/trt_models_tester.cc b/paddle/fluid/inference/tests/api/trt_models_tester.cc index 922feba10fec5d1d13b47dbce064fce2e01d8998..ef612ce6148329c33f194842945bb5438afcf645 100644 --- a/paddle/fluid/inference/tests/api/trt_models_tester.cc +++ b/paddle/fluid/inference/tests/api/trt_models_tester.cc @@ -145,5 +145,3 @@ TEST(TensorRT_mobilenet, analysis) { } // namespace inference } // namespace paddle - -USE_PASS(tensorrt_subgraph_pass); diff --git a/paddle/fluid/operators/elementwise/elementwise_mul_mkldnn_op.cc b/paddle/fluid/operators/elementwise/elementwise_mul_mkldnn_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..10290a4aeff6b6a023fb28961d12728aff891e83 --- /dev/null +++ b/paddle/fluid/operators/elementwise/elementwise_mul_mkldnn_op.cc @@ -0,0 +1,201 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include "paddle/fluid/operators/elementwise/elementwise_op.h" +#include "paddle/fluid/operators/elementwise/elementwise_op_function.h" + +#include "paddle/fluid/platform/mkldnn_helper.h" + +#include "paddle/fluid/operators/math/jit_kernel.h" +#include "xbyak.h" +#include "xbyak_util.h" + +namespace paddle { +namespace operators { + +using framework::DataLayout; +using mkldnn::memory; + +static mkldnn::memory::format StringToMKLDNNFormat(std::string& format) { + std::transform(format.begin(), format.end(), format.begin(), ::tolower); + + if (!format.compare("nchw")) { + return memory::format::nchw; + } else if (!format.compare("nchw16c")) { + return memory::format::nChw16c; + } else if (!format.compare("nchw8c")) { + return memory::format::nChw8c; + } else if (!format.compare("nhwc")) { + return memory::format::nhwc; + } else { + return memory::format::any; + } +} + +static void UpdateDataFormat(const framework::ExecutionContext& ctx, + framework::Tensor* tensor, const char* attribute) { + if (ctx.op().HasAttr(attribute)) { + auto format_as_string = ctx.Attr(attribute); + auto format = StringToMKLDNNFormat(format_as_string); + if (format != memory::format::any) { + tensor->set_format(format); + } + } +} + +template +static void ReorderInput(framework::Tensor* tensor, + const platform::Place& place, + const mkldnn::engine& engine, bool isFourDim) { + using platform::to_void_cast; + auto dims = paddle::framework::vectorize2int(tensor->dims()); + framework::Tensor out_tensor; + out_tensor.Resize(tensor->dims()); + out_tensor.set_format(isFourDim ? memory::format::nchw : memory::format::nc); + out_tensor.set_layout(tensor->layout()); + mkldnn::memory input_memory = { + {{dims, platform::MKLDNNGetDataType(), tensor->format()}, engine}, + to_void_cast(tensor->data())}; + mkldnn::memory output_memory = { + {{dims, platform::MKLDNNGetDataType(), out_tensor.format()}, engine}, + to_void_cast(out_tensor.mutable_data(place))}; + platform::Reorder(input_memory, output_memory); + tensor->ShareDataWith(out_tensor); +} + +template +class ElementwiseMulMKLDNNKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + using Tensor = framework::Tensor; + + int axis = ctx.Attr("axis"); + auto* x = ctx.Input("X"); + auto* y = ctx.Input("Y"); + auto* z = ctx.Output("Out"); + const T* x_data = x->data(); + const T* y_data = y->data(); + T* z_data = z->mutable_data(ctx.GetPlace()); + + auto x_dims = x->dims(); + auto y_dims_untrimmed = y->dims(); + auto x_int_dims = paddle::framework::vectorize2int(x_dims); + + UpdateDataFormat(ctx, (Tensor*)x, "x_data_format"); + UpdateDataFormat(ctx, (Tensor*)y, "y_data_format"); + + Xbyak::util::Cpu cpu; + const bool is_avx512_enabled = cpu.has(Xbyak::util::Cpu::tAVX512F); + const bool are_dims_divisable = !(x_int_dims[1] % 16); + const bool is_x_format_correct = x->format() == memory::format::nChw16c; + const bool is_y_format_correct = y->format() == memory::format::nc; + if (is_x_format_correct && is_y_format_correct && are_dims_divisable && + is_avx512_enabled) { + int pre, n, post; + get_mid_dims(x_dims, y_dims_untrimmed, axis, &pre, &n, &post); + + if (post == 1) { + PADDLE_THROW("Not implemented when post is 1"); + } else { + // Just check whether it works for RE-Resnext. + PADDLE_ENFORCE_EQ(x_dims.size(), 4, "X should have 4 dimensions"); + + int n = x_dims[0]; + int c = x_dims[1]; + int h = x_dims[2]; + int w = x_dims[3]; + + PADDLE_ENFORCE(y_dims_untrimmed[0] == n && y_dims_untrimmed[1] == c, + "Y should be in nc format"); + + constexpr int simd_width = 16; + int C = c / simd_width; + + const auto& multiply = + math::jitkernel::KernelPool::Instance() + .template Get>(n); + +#pragma omp parallel for collapse(2) + for (int ni = 0; ni < n; ni++) { + for (int ci = 0; ci < C; ci++) { + auto ptr_x = + x_data + ni * C * h * w * simd_width + ci * h * w * simd_width; + + auto ptr_y = y_data + ni * C * simd_width + ci * simd_width; + auto ptr_z = + z_data + ni * C * h * w * simd_width + ci * h * w * simd_width; + + multiply->Compute(ptr_x, ptr_y, ptr_z, h, w); + } + } + } + + z->set_layout(DataLayout::kMKLDNN); + z->set_format(x->format()); + } else { + // Fallback to naive version: + const bool are_inputs_in_same_format = x->format() == y->format(); + const bool is_x_nchw = x->format() == memory::format::nchw; + const bool is_x_nc = x->format() == memory::format::nc; + const bool is_y_nchw = y->format() == memory::format::nchw; + const bool is_y_nc = y->format() == memory::format::nc; + if (!are_inputs_in_same_format) { + using platform::MKLDNNDeviceContext; + auto& dev_ctx = ctx.template device_context(); + const auto& mkldnn_engine = dev_ctx.GetEngine(); + if (!(is_x_nchw || is_x_nc)) + ReorderInput((Tensor*)x, ctx.GetPlace(), mkldnn_engine, + x->dims().size() == 4); + if (!(is_y_nchw || is_y_nc)) + ReorderInput((Tensor*)y, ctx.GetPlace(), mkldnn_engine, + y->dims().size() == 4); + } + + auto mul_func = [](T a, T b) -> T { return a * b; }; + + TransformFunctor + functor( + x, y, z, + ctx.template device_context(), + mul_func); + + axis = (axis == -1 ? x_dims.size() - y_dims_untrimmed.size() : axis); + PADDLE_ENFORCE(axis >= 0 && axis < x_dims.size(), + "Axis should be in range [0, x_dims)"); + + auto y_dims = trim_trailing_singular_dims(y_dims_untrimmed); + axis = (y_dims.size() == 0) ? x_dims.size() : axis; + + int pre, n, post; + get_mid_dims(x_dims, y_dims, axis, &pre, &n, &post); + + if (post == 1) { + functor.RunRowWise(n, pre); + } else { + functor.RunMidWise(n, pre, post); + } + z->set_layout(DataLayout::kMKLDNN); + z->set_format(x->format()); + } + } +}; +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +REGISTER_OP_KERNEL(elementwise_mul, MKLDNN, ::paddle::platform::CPUPlace, + ops::ElementwiseMulMKLDNNKernel) diff --git a/paddle/fluid/operators/elementwise/elementwise_op.h b/paddle/fluid/operators/elementwise/elementwise_op.h index f01f67692e1e5dd040971cb0dd1dd793648da97a..85a7817be9b3a82d40853b417d78a7fdf67f6c1f 100644 --- a/paddle/fluid/operators/elementwise/elementwise_op.h +++ b/paddle/fluid/operators/elementwise/elementwise_op.h @@ -97,6 +97,20 @@ class ElementwiseOpMaker : public framework::OpProtoAndCheckerMaker { .EqualGreaterThan(-1); AddAttr("use_mkldnn", "(bool, default false). Used by MKLDNN.") .SetDefault(false); + AddAttr( + "x_data_format", + "(string, default NCHW) Only used in mkldnn" + "An optional string from: \"NHWC\", \"NCHW\", \"NCHW16C\", \"NCHW8C\". " + "Defaults to \"\". Specify the data format of the output data, " + "the input will be transformed automatically. ") + .SetDefault(""); + AddAttr( + "y_data_format", + "(string, default \"\") Only used in mkldnn" + "An optional string from: \"NHWC\", \"NCHW\", \"NCHW16C\", \"NCHW8C\". " + "Defaults to \"\". Specify the data format of the output data, " + "the input will be transformed automatically. ") + .SetDefault(""); AddComment(string::Sprintf(R"DOC( Elementwise %s Operator diff --git a/paddle/fluid/operators/math/jit_code.h b/paddle/fluid/operators/math/jit_code.h index 65f83ff4846601d1575daa994772cd869d526f56..64ef55de7cf73fea4538cc0d8fa6d316ddaff2f8 100644 --- a/paddle/fluid/operators/math/jit_code.h +++ b/paddle/fluid/operators/math/jit_code.h @@ -322,6 +322,42 @@ class VActJitCode : public JitCode { ymm_t ymm_dst = ymm_t(1); }; +#ifdef PADDLE_WITH_MKLDNN +struct EltwiseMulnChw16cNC : public Xbyak::CodeGenerator { + explicit EltwiseMulnChw16cNC(size_t code_size = 256 * 1024) + : Xbyak::CodeGenerator(code_size) { + // RDI is ptr x_input + // RSI is ptr y_input + // RDX is ptr output + // RCX is height + // r8 is width + + push(rbx); + + xor_(rax, rax); + xor_(r10, r10); + vmovups(zmm3, ptr[rsi]); + + L("h_loop"); + xor_(rbx, rbx); + L("w_loop"); + vmovups(zmm2, ptr[rdi + rax]); + vmulps(zmm1, zmm2, zmm3); + vmovups(ptr[rdx + rax], zmm1); + add(rax, 64); + inc(rbx); + cmp(r8, rbx); + jnz("w_loop"); + inc(r10); + cmp(r10, rcx); + jnz("h_loop"); + + pop(rbx); + ret(); + } +}; +#endif + } // namespace gen } // namespace jitkernel } // namespace math diff --git a/paddle/fluid/operators/math/jit_kernel.h b/paddle/fluid/operators/math/jit_kernel.h index 7e163c1349e73d8fe5e436b98c9a8f67e6439506..82d808f415c3b4ed2688d034aad13610ae2ab0f4 100644 --- a/paddle/fluid/operators/math/jit_kernel.h +++ b/paddle/fluid/operators/math/jit_kernel.h @@ -95,6 +95,15 @@ class VAddBiasKernel : public Kernel { void (*Compute)(const T *, const T *, T *, int); }; +#ifdef PADDLE_WITH_MKLDNN +template +class EltwiseMulnChw16cNCKernel : public Kernel { + public: + // nChw16c = nChw16c .* NC + void (*Compute)(const float *, const float *, float *, int, int); +}; +#endif + template class VActKernel : public Kernel { public: diff --git a/paddle/fluid/operators/math/jit_kernel_blas.cc b/paddle/fluid/operators/math/jit_kernel_blas.cc index 36a50f20434f313e93bfa3dd2c9d46963024caf7..a143b51439f55d1f80d7936dfad46e31bd19f0cb 100644 --- a/paddle/fluid/operators/math/jit_kernel_blas.cc +++ b/paddle/fluid/operators/math/jit_kernel_blas.cc @@ -226,6 +226,44 @@ bool VAddKernelImpl::useMKL(int d) { } #endif +#ifdef PADDLE_WITH_MKLDNN +/* EltwiseMul for nChw16c & NC inputs JitKernel */ +template +class EltwiseMulnChw16cNCKernelImpl + : public math::jitkernel::EltwiseMulnChw16cNCKernel { + public: + JITKERNEL_DECLARE_STATIC_FUNC; + explicit EltwiseMulnChw16cNCKernelImpl(int d) + : EltwiseMulnChw16cNCKernel() { + using mul_func_t = void (*)(const float*, const float*, float*, int, int); +#ifdef PADDLE_WITH_XBYAK + if (useJIT(d)) { + // roughly estimate the size of code + size_t sz = 96 + d / YMM_FLOAT_BLOCK * 4 * 8; + sz = sz > 4096 ? sz : 4096; + jitcode_.reset(new gen::EltwiseMulnChw16cNC(sz)); + this->Compute = (mul_func_t)jitcode_->getCode(); + return; + } +#endif + PADDLE_THROW( + "This kernel shouldn't be used in Non-Xbyak, Non-MKL-DNN " + "environemnt"); + } + +#ifdef PADDLE_WITH_XBYAK + + private: + std::unique_ptr jitcode_{nullptr}; +}; + +template <> +bool EltwiseMulnChw16cNCKernelImpl::useJIT(int d) { + return true; +} +#endif +#endif + /* VAddRelu JitKernel */ template class VAddReluKernelImpl : public VAddReluKernel { @@ -394,6 +432,9 @@ REGISTER_JITKERNEL(vscal, VScalKernel); REGISTER_JITKERNEL(vaddbias, VAddBiasKernel); REGISTER_JITKERNEL(vrelu, VReluKernel); REGISTER_JITKERNEL(videntity, VIdentityKernel); +#ifdef PADDLE_WITH_MKLDNN +REGISTER_JITKERNEL(eltwise_mul_nchw16c, EltwiseMulnChw16cNCKernel); +#endif } // namespace jitkernel } // namespace math diff --git a/paddle/fluid/platform/init.cc b/paddle/fluid/platform/init.cc index 02639e2177f0ef2a2f704fce724defe11cc09045..0ccef6c6a8345e31cee3ef2422fe3f56c059c231 100644 --- a/paddle/fluid/platform/init.cc +++ b/paddle/fluid/platform/init.cc @@ -38,6 +38,7 @@ std::once_flag p2p_init_flag; void InitGflags(std::vector argv) { std::call_once(gflags_init_flag, [&]() { + FLAGS_logtostderr = true; argv.insert(argv.begin(), "dummy"); int argc = argv.size(); char **arr = new char *[argv.size()]; diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index 1513eca51439288acac35729300bcbe4e71e4205..29e4ca04a7fbb2eae870fcf15763310b849c8b53 100644 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -45,6 +45,10 @@ if(APPLE) list(REMOVE_ITEM TEST_OPS test_dist_se_resnext) list(REMOVE_ITEM TEST_OPS test_fuse_elewise_add_act_pass) endif() +if(NOT WITH_MKLML) + # this op is not support on openblas + list(REMOVE_ITEM TEST_OPS test_fusion_seqexpand_concat_fc_op) +endif() function(py_test_modules TARGET_NAME) if(WITH_TESTING) diff --git a/python/paddle/fluid/tests/unittests/op_test.py b/python/paddle/fluid/tests/unittests/op_test.py index 690c4cf0ad6b2c741689e419223cfa6b6e1e5cf3..c195a28e452fbe073a9afb5d650f538176f688fd 100644 --- a/python/paddle/fluid/tests/unittests/op_test.py +++ b/python/paddle/fluid/tests/unittests/op_test.py @@ -362,7 +362,9 @@ class OpTest(unittest.TestCase): else: return [] places = [fluid.CPUPlace()] - if core.is_compiled_with_cuda() and core.op_support_gpu(self.op_type): + cpu_only = self._cpu_only if hasattr(self, '_cpu_only') else False + if core.is_compiled_with_cuda() and core.op_support_gpu(self.op_type)\ + and not cpu_only: places.append(core.CUDAPlace(0)) return places diff --git a/python/paddle/fluid/tests/unittests/test_elementwise_mul_mkldnn_op.py b/python/paddle/fluid/tests/unittests/test_elementwise_mul_mkldnn_op.py new file mode 100644 index 0000000000000000000000000000000000000000..536e9a1c58ec4a8b1b5a7c1d3a5fe737b38d24ab --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_elementwise_mul_mkldnn_op.py @@ -0,0 +1,263 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function +import unittest +import numpy as np +from op_test import OpTest +import paddle.fluid.core as core +from paddle.fluid.op import Operator +from test_elementwise_mul_op import * + + +class TestElementwiseMulMKLDNNOp_BroadcastNCHW16c(ElementwiseMulOp): + def init_input_output(self): + x = np.random.rand(1, 16, 2, 2).astype(self.dtype) + self.x = x.transpose(0, 2, 3, 1).reshape(1, 16, 2, 2) + self.y = np.random.rand(1, 16).astype(self.dtype) + + self.out = x * self.y.reshape(1, 16, 1, 1) + self.out = self.out.transpose(0, 2, 3, 1).reshape(1, 16, 2, 2) + + def setUp(self): + super(TestElementwiseMulMKLDNNOp_BroadcastNCHW16c, self).setUp() + self.attrs["x_data_format"] = "nchw16c" + self.attrs["y_data_format"] = "nc" + self._cpu_only = True + + def init_kernel_type(self): + self.use_mkldnn = True + + def init_axis(self): + self.axis = 0 + + def test_check_grad_normal(self): + pass + + def test_check_grad_ingore_x(self): + pass + + def test_check_grad_ingore_y(self): + pass + + +@unittest.skip( + "Not implemented yet.") # TODO(mgallus): enable when implemented. +class TestElementwiseMulMKLDNNOp_BroadcastNCHW8c(ElementwiseMulOp): + def init_input_output(self): + x = np.random.rand(1, 8, 2, 2).astype(self.dtype) + self.x = x.transpose(0, 2, 3, 1).reshape(1, 8, 2, 2) + self.y = np.random.rand(1, 8).astype(self.dtype) + + self.out = x * self.y.reshape(1, 8, 1, 1) + self.out = self.out.transpose(0, 2, 3, 1).reshape(1, 8, 2, 2) + + def setUp(self): + super(TestElementwiseMulMKLDNNOp_BroadcastNCHW8c, self).setUp() + self.attrs["x_data_format"] = "nchw8c" + self.attrs["y_data_format"] = "nc" + self._cpu_only = True + + def init_kernel_type(self): + self.use_mkldnn = True + + def init_axis(self): + self.axis = 0 + + def test_check_grad_normal(self): + pass + + def test_check_grad_ingore_x(self): + pass + + def test_check_grad_ingore_y(self): + pass + + +class TestElementwiseMulMKLDNNOp_FallbackNCHW(ElementwiseMulOp): + def init_input_output(self): + self.x = np.random.rand(1, 16, 2, 2).astype(self.dtype) + self.y = np.random.rand(1, 16).astype(self.dtype) + + self.out = self.x * self.y.reshape(1, 16, 1, 1) + + def init_kernel_type(self): + self.use_mkldnn = True + + def init_axis(self): + self.axis = 0 + + def test_check_grad_normal(self): + pass + + def test_check_grad_ingore_x(self): + pass + + def test_check_grad_ingore_y(self): + pass + + +class TestElementwiseMulMKLDNNOp_FallbackNCHW16C(ElementwiseMulOp): + def init_input_output(self): + x = np.random.rand(1, 16, 2, 2).astype(self.dtype) + self.x = x.transpose(0, 2, 3, 1).reshape(1, 16, 2, 2) + y = np.random.rand(1, 16, 2, 2).astype(self.dtype) + self.y = y.transpose(0, 2, 3, 1).reshape(1, 16, 2, 2) + + self.out = self.x * self.y + + def setUp(self): + super(TestElementwiseMulMKLDNNOp_FallbackNCHW16C, self).setUp() + self.attrs["x_data_format"] = "nchw16c" + self.attrs["y_data_format"] = "nchw16c" + self._cpu_only = True + + def init_kernel_type(self): + self.use_mkldnn = True + + def init_axis(self): + self.axis = 0 + + def test_check_grad_normal(self): + pass + + def test_check_grad_ingore_x(self): + pass + + def test_check_grad_ingore_y(self): + pass + + +class TestElementwiseMulMKLDNNOp_FallbackNoReorders(ElementwiseMulOp): + def init_input_output(self): + x = np.random.rand(1, 16, 2, 2).astype(self.dtype) + self.x = x.transpose(0, 2, 3, 1).reshape(1, 16, 2, 2) + y = np.random.rand(1, 16, 2, 2).astype(self.dtype) + self.y = y.transpose(0, 2, 3, 1).reshape(1, 16, 2, 2) + + self.out = self.x * self.y + + def setUp(self): + super(TestElementwiseMulMKLDNNOp_FallbackNoReorders, self).setUp() + self.attrs["x_data_format"] = "nchw16c" + self.attrs["y_data_format"] = "nchw16c" + self._cpu_only = True + + def init_kernel_type(self): + self.use_mkldnn = True + + def init_axis(self): + self.axis = 0 + + def test_check_grad_normal(self): + pass + + def test_check_grad_ingore_x(self): + pass + + def test_check_grad_ingore_y(self): + pass + + +class TestElementwiseMulMKLDNNOp_FallbackWithReorder1(ElementwiseMulOp): + def init_input_output(self): + self.x = np.random.rand(1, 16, 2, 2).astype(self.dtype) + y = np.random.rand(1, 16, 2, 2).astype(self.dtype) + self.y = y.transpose(0, 2, 3, 1).reshape(1, 16, 2, 2) + + self.out = self.x * y + + def setUp(self): + super(TestElementwiseMulMKLDNNOp_FallbackWithReorder1, self).setUp() + self.attrs["x_data_format"] = "nchw" + self.attrs["y_data_format"] = "nchw16c" + self._cpu_only = True + + def init_kernel_type(self): + self.use_mkldnn = True + + def init_axis(self): + self.axis = 0 + + def test_check_grad_normal(self): + pass + + def test_check_grad_ingore_x(self): + pass + + def test_check_grad_ingore_y(self): + pass + + +class TestElementwiseMulMKLDNNOp_FallbackWithReorder2(ElementwiseMulOp): + def init_input_output(self): + self.y = np.random.rand(1, 16, 2, 2).astype(self.dtype) + x = np.random.rand(1, 16, 2, 2).astype(self.dtype) + self.x = x.transpose(0, 2, 3, 1).reshape(1, 16, 2, 2) + + self.out = x * self.y + + def setUp(self): + super(TestElementwiseMulMKLDNNOp_FallbackWithReorder2, self).setUp() + self.attrs["x_data_format"] = "nchw16c" + self.attrs["y_data_format"] = "nchw" + self._cpu_only = True + + def init_kernel_type(self): + self.use_mkldnn = True + + def init_axis(self): + self.axis = 0 + + def test_check_grad_normal(self): + pass + + def test_check_grad_ingore_x(self): + pass + + def test_check_grad_ingore_y(self): + pass + + +class TestElementwiseMulMKLDNNOp_FallbackNoReorders2(ElementwiseMulOp): + def init_input_output(self): + self.x = np.random.rand(1, 16).astype(self.dtype) + self.y = np.random.rand(1, 16).astype(self.dtype) + + self.out = self.x * self.y + + def setUp(self): + super(TestElementwiseMulMKLDNNOp_FallbackNoReorders2, self).setUp() + self.attrs["x_data_format"] = "nc" + self.attrs["y_data_format"] = "nc" + self._cpu_only = True + + def init_kernel_type(self): + self.use_mkldnn = True + + def init_axis(self): + self.axis = 0 + + def test_check_grad_normal(self): + pass + + def test_check_grad_ingore_x(self): + pass + + def test_check_grad_ingore_y(self): + pass + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_elementwise_mul_op.py b/python/paddle/fluid/tests/unittests/test_elementwise_mul_op.py index 53409e436c0739bce63a3a8f90591e0ca6836859..57ba34f833f824d13e0b82caea789f7f57622bc9 100644 --- a/python/paddle/fluid/tests/unittests/test_elementwise_mul_op.py +++ b/python/paddle/fluid/tests/unittests/test_elementwise_mul_op.py @@ -21,13 +21,24 @@ from paddle.fluid.op import Operator class ElementwiseMulOp(OpTest): + def init_kernel_type(self): + self.use_mkldnn = False + def setUp(self): self.op_type = "elementwise_mul" + self.dtype = np.float32 + self.axis = -1 + self.init_dtype() + self.init_input_output() + self.init_kernel_type() + self.init_axis() + self.inputs = { - 'X': np.random.uniform(0.1, 1, [13, 17]).astype("float64"), - 'Y': np.random.uniform(0.1, 1, [13, 17]).astype("float64") + 'X': OpTest.np_dtype_to_fluid_dtype(self.x), + 'Y': OpTest.np_dtype_to_fluid_dtype(self.y) } - self.outputs = {'Out': np.multiply(self.inputs['X'], self.inputs['Y'])} + self.outputs = {'Out': self.out} + self.attrs = {'axis': self.axis, 'use_mkldnn': self.use_mkldnn} def test_check_output(self): self.check_output() @@ -41,6 +52,17 @@ class ElementwiseMulOp(OpTest): def test_check_grad_ingore_y(self): self.check_grad(['X'], 'Out', no_grad_set=set('Y')) + def init_input_output(self): + self.x = np.random.uniform(0.1, 1, [13, 17]).astype(self.dtype) + self.y = np.random.uniform(0.1, 1, [13, 17]).astype(self.dtype) + self.out = np.multiply(self.x, self.y) + + def init_dtype(self): + pass + + def init_axis(self): + pass + class TestElementwiseMulOp_scalar(ElementwiseMulOp): def setUp(self): @@ -63,17 +85,13 @@ class TestElementwiseMulOp_Vector(ElementwiseMulOp): class TestElementwiseMulOp_broadcast_0(ElementwiseMulOp): - def setUp(self): - self.op_type = "elementwise_mul" - self.inputs = { - 'X': np.random.rand(2, 3, 4).astype(np.float64), - 'Y': np.random.rand(2).astype(np.float64) - } + def init_input_output(self): + self.x = np.random.rand(2, 3, 4).astype(self.dtype) + self.y = np.random.rand(2).astype(self.dtype) + self.out = self.x * self.y.reshape(2, 1, 1) - self.attrs = {'axis': 0} - self.outputs = { - 'Out': self.inputs['X'] * self.inputs['Y'].reshape(2, 1, 1) - } + def init_axis(self): + self.axis = 0 class TestElementwiseMulOp_broadcast_1(ElementwiseMulOp):