diff --git a/paddle/fluid/framework/ir/fc_fuse_pass.cc b/paddle/fluid/framework/ir/fc_fuse_pass.cc
index 3348abb19b3339b2b3e8b50485133b15a1973a32..7b6ce0da07309a0ed2a5c8bcd5f59d84105261d7 100644
--- a/paddle/fluid/framework/ir/fc_fuse_pass.cc
+++ b/paddle/fluid/framework/ir/fc_fuse_pass.cc
@@ -57,6 +57,7 @@ std::unique_ptr<ir::Graph> FCFusePass::ApplyImpl(
   desc.SetInput("W", std::vector<std::string>({fc_Y_in}));
   desc.SetInput("Bias", std::vector<std::string>({fc_bias_in}));
   desc.SetOutput("Out", std::vector<std::string>({fc_out_out}));
+  desc.SetAttr("in_num_col_dims", mul->Op()->GetAttr("x_num_col_dims"));
   desc.SetType("fc");
   auto fc_node = g->CreateOpNode(&desc);  // OpDesc will be copied.
   GraphSafeRemoveNodes(graph.get(), {mul, elementwise_add, mul_out});
diff --git a/paddle/fluid/framework/ir/fc_fuse_pass_tester.cc b/paddle/fluid/framework/ir/fc_fuse_pass_tester.cc
index 2db7d95cae1c8c59691fd642e2462e92ed58814f..4e1e4e27f9ba932b56ecc25e816a2aee9d42362e 100644
--- a/paddle/fluid/framework/ir/fc_fuse_pass_tester.cc
+++ b/paddle/fluid/framework/ir/fc_fuse_pass_tester.cc
@@ -29,6 +29,7 @@ void SetOp(ProgramDesc* prog, const std::string& type,
   if (type == "mul") {
     op->SetInput("X", {inputs[0]});
     op->SetInput("Y", {inputs[1]});
+    op->SetAttr("x_num_col_dims", {1});
   } else if (type == "elementwise_add") {
     op->SetInput("X", inputs);
   }
diff --git a/paddle/fluid/inference/analysis/ir_passes/subgraph_detector.cc b/paddle/fluid/inference/analysis/ir_passes/subgraph_detector.cc
index e903ec54cc4ed25ab0648c8c19caa2c8bb00b94f..b6a5dfd087c95d0ccb0f5adfa4f754cfa5a44f14 100644
--- a/paddle/fluid/inference/analysis/ir_passes/subgraph_detector.cc
+++ b/paddle/fluid/inference/analysis/ir_passes/subgraph_detector.cc
@@ -412,7 +412,7 @@ void DetachDeletedNodes(framework::ir::Graph *graph) {
 void SubGraphFuser::ReplaceNodesWithSubGraphs() {
   auto subgraphs = SubgraphDetector(graph_, node_inside_subgraph_teller_)();
   for (auto &subgraph : subgraphs) {
-    if (subgraph.size() <= min_subgraph_size_) continue;
+    if (subgraph.size() <= (size_t)min_subgraph_size_) continue;
     LOG(INFO) << "detect a subgraph size " << subgraph.size();
     std::unordered_set<Node *> subgraph_uniq(subgraph.begin(), subgraph.end());
     // replace this sub-graph with the first node. Two steps: 1. Create a Block
diff --git a/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc b/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc
index f27347b9d176eae8fbd087a21bdedb9cb84085e6..21fd8d2df49698d7fa38d906f7921f092ca916a3 100644
--- a/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc
+++ b/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc
@@ -114,7 +114,7 @@ void TensorRtSubgraphPass::CreateTensorRTOp(framework::ir::Node *node,
   // it is either an OP's input or an OP's output.
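   // The ops in the block are expected to be in the same order as the
   // subgraph nodes; the per-index name check below relies on this.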
   auto &subgraph_nodes = *Agent(node).subgraph();
-  for (int index = 0; index < block_desc.OpSize(); index++) {
+  for (size_t index = 0; index < block_desc.OpSize(); index++) {
     framework::proto::OpDesc *op = block_desc.Op(index)->Proto();
     auto correspond_node = subgraph_nodes[index];
     PADDLE_ENFORCE_EQ(correspond_node->Name(), op->type());
diff --git a/paddle/fluid/inference/analysis/passes/ir_analysis_compose_pass.cc b/paddle/fluid/inference/analysis/passes/ir_analysis_compose_pass.cc
index dc4d0906c4f260c8f7a11832fc52eba7191c54e8..233bfd6a42b7f123813d4ef5cecf353f7e88d208 100644
--- a/paddle/fluid/inference/analysis/passes/ir_analysis_compose_pass.cc
+++ b/paddle/fluid/inference/analysis/passes/ir_analysis_compose_pass.cc
@@ -45,7 +45,7 @@ void IrAnalysisComposePass::InitTensorRTAttrs(Argument *argument) {
     std::unordered_set<std::string> teller_set(
         {"mul", "conv2d", "pool2d", "relu", "softmax", "sigmoid",
          "depthwise_conv2d", "batch_norm", "concat", "tanh", "pad",
-         "elementwise_add", "dropout"});
+         "elementwise_add", "dropout", "split"});
 
     if (!node->IsOp()) return false;
     if (teller_set.count(node->Op()->Type())) {
diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc
index 7407a1ba2f63bfe31a9d3a6f33395575c5809dee..76d205b737aeb456f242037f2b375d9c537b39f3 100644
--- a/paddle/fluid/inference/api/analysis_predictor.cc
+++ b/paddle/fluid/inference/api/analysis_predictor.cc
@@ -548,4 +548,5 @@ USE_TRT_CONVERTER(batch_norm);
 USE_TRT_CONVERTER(concat);
 USE_TRT_CONVERTER(dropout);
 USE_TRT_CONVERTER(pad);
+USE_TRT_CONVERTER(split);
 #endif
diff --git a/paddle/fluid/inference/api/analysis_predictor_tester.cc b/paddle/fluid/inference/api/analysis_predictor_tester.cc
index 1e6f75e364cbe66d141cf2336f065d50928d1bc2..d67305670c91bb0814b8771332641e96974ade9d 100644
--- a/paddle/fluid/inference/api/analysis_predictor_tester.cc
+++ b/paddle/fluid/inference/api/analysis_predictor_tester.cc
@@ -15,7 +15,7 @@
 #include "paddle/fluid/inference/api/analysis_predictor.h"
 #include <glog/logging.h>
 #include <gtest/gtest.h>
-#include <thread>
+#include <thread>  // NOLINT
 #include "paddle/fluid/inference/api/helper.h"
 #include "paddle/fluid/inference/api/paddle_inference_api.h"
diff --git a/paddle/fluid/inference/api/demo_ci/simple_on_word2vec.cc b/paddle/fluid/inference/api/demo_ci/simple_on_word2vec.cc
index 6ae5198dab9a16d5a861c641dee39e4806595352..3dd1d3c838c4b1bcdefdadff16b02dbfb4a02ee9 100644
--- a/paddle/fluid/inference/api/demo_ci/simple_on_word2vec.cc
+++ b/paddle/fluid/inference/api/demo_ci/simple_on_word2vec.cc
@@ -23,7 +23,7 @@ limitations under the License. */
 #include <memory>
 #include <thread>  //NOLINT
-#include "utils.h"
+#include "utils.h"  // NOLINT
 
 DEFINE_string(dirname, "", "Directory of the inference model.");
 DEFINE_bool(use_gpu, false, "Whether use gpu.");
diff --git a/paddle/fluid/inference/api/demo_ci/trt_mobilenet_demo.cc b/paddle/fluid/inference/api/demo_ci/trt_mobilenet_demo.cc
index 72d20bc59e036afb84e2501f6af75c09be78b57e..0eb620ea516d28fb9598af8dbd297e84580a99f9 100644
--- a/paddle/fluid/inference/api/demo_ci/trt_mobilenet_demo.cc
+++ b/paddle/fluid/inference/api/demo_ci/trt_mobilenet_demo.cc
@@ -4,7 +4,7 @@ Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at
-http://www.apache.org/licenses/LICENSE-2.0
+    http://www.apache.org/licenses/LICENSE-2.0
 
 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
diff --git a/paddle/fluid/inference/api/paddle_analysis_config.h b/paddle/fluid/inference/api/paddle_analysis_config.h
index 82c04e9f3f043df9db82969e2a5ce8825a3a41f6..2ac736df7ccd54babe582ca1383903c191069d33 100644
--- a/paddle/fluid/inference/api/paddle_analysis_config.h
+++ b/paddle/fluid/inference/api/paddle_analysis_config.h
@@ -49,6 +49,8 @@ struct AnalysisConfig : public NativeConfig {
   void EnableTensorRtEngine(int workspace_size = 1 << 20,
                             int max_batch_size = 1);
+  bool use_tensorrt() const { return use_tensorrt_; }
+
   // NOTE this is just for internal development, please not use it.
   // NOT stable yet.
   void EnableMKLDNN();
diff --git a/paddle/fluid/inference/api/paddle_pass_builder.h b/paddle/fluid/inference/api/paddle_pass_builder.h
index 8aad5c5984891546776bc52327337c94c387d6dc..80658d30850aaa7212903828c5c963da5f37ca65 100644
--- a/paddle/fluid/inference/api/paddle_pass_builder.h
+++ b/paddle/fluid/inference/api/paddle_pass_builder.h
@@ -91,7 +91,7 @@ class CpuPassStrategy : public PassStrategy {
   virtual ~CpuPassStrategy() = default;
 
-  virtual void EnableMKLDNN() override {
+  void EnableMKLDNN() override {
     // TODO(Superjomn) Consider the way to mix CPU with GPU.
 #ifdef PADDLE_WITH_MKLDNN
     passes_.insert(passes_.begin(), "mkldnn_placement_pass");
@@ -123,7 +123,7 @@ class GpuPassStrategy : public PassStrategy {
   GpuPassStrategy(const GpuPassStrategy &other)
       : PassStrategy(other.AllPasses()) {}
 
-  virtual void EnableMKLDNN() override;
+  void EnableMKLDNN() override;
 
   virtual ~GpuPassStrategy() = default;
 };
diff --git a/paddle/fluid/inference/tensorrt/CMakeLists.txt b/paddle/fluid/inference/tensorrt/CMakeLists.txt
index a610687a5b11999a7cb7426dbe961e5972ee1746..e09705e3c69eb2b2370bd1ad2d9cf178ef041ee6 100644
--- a/paddle/fluid/inference/tensorrt/CMakeLists.txt
+++ b/paddle/fluid/inference/tensorrt/CMakeLists.txt
@@ -1,4 +1,5 @@
 nv_library(tensorrt_engine SRCS engine.cc DEPS framework_proto device_context)
 nv_test(test_tensorrt SRCS test_tensorrt.cc DEPS dynload_cuda device_context dynamic_loader)
 nv_test(test_tensorrt_engine SRCS test_engine.cc DEPS dynload_cuda tensorrt_engine)
+add_subdirectory(plugin)
 add_subdirectory(convert)
diff --git a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt
index 0a35e10f6936313928ab21a6f17c40335e8fc882..ed4c398cee518af3211cab4e982082c46ebb36c2 100644
--- a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt
+++ b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt
@@ -1,8 +1,9 @@
 # Add TRT tests
 nv_library(tensorrt_converter
   SRCS mul_op.cc conv2d_op.cc fc_op.cc pool2d_op.cc elementwise_op.cc
-batch_norm_op.cc activation_op.cc softmax_op.cc concat_op.cc dropout_op.cc pad_op.cc
-  DEPS tensorrt_engine operator scope framework_proto op_registry)
+batch_norm_op.cc activation_op.cc softmax_op.cc concat_op.cc dropout_op.cc
+pad_op.cc split_op.cc
+  DEPS tensorrt_engine tensorrt_plugin operator scope framework_proto op_registry)
 
 nv_test(test_op_converter SRCS test_op_converter.cc DEPS
   ${FLUID_CORE_MODULES} tensorrt_engine tensorrt_converter)
@@ -28,6 +29,8 @@ nv_test(test_trt_concat_op SRCS test_concat_op.cc concat_op.cc
         DEPS ${FLUID_CORE_MODULES} tensorrt_engine concat_op SERIAL)
 nv_test(test_trt_dropout_op SRCS test_dropout_op.cc
         dropout_op.cc DEPS ${FLUID_CORE_MODULES} tensorrt_engine dropout_op SERIAL)
-
 nv_test(test_trt_pad_op SRCS test_pad_op.cc pad_op.cc
         DEPS ${FLUID_CORE_MODULES} tensorrt_engine pad_op SERIAL)
+nv_test(test_trt_split_op SRCS test_split_op.cc split_op.cc
+        DEPS ${FLUID_CORE_MODULES} tensorrt_engine tensorrt_plugin
+split_op concat_op SERIAL)
diff --git a/paddle/fluid/inference/tensorrt/convert/concat_op.cc b/paddle/fluid/inference/tensorrt/convert/concat_op.cc
index b2e7c593e85974898012f8a353817a27ca212f4d..525ba9dc341c8c1343553ac9523611f79ac3aa2d 100644
--- a/paddle/fluid/inference/tensorrt/convert/concat_op.cc
+++ b/paddle/fluid/inference/tensorrt/convert/concat_op.cc
@@ -19,7 +19,7 @@ namespace inference {
 namespace tensorrt {
 
 /*
- * MulOp, IMatrixMultiplyLayer in TRT. This Layer doesn't has weights.
+ * ConcatOp
  */
 class ConcatOpConverter : public OpConverter {
  public:
diff --git a/paddle/fluid/inference/tensorrt/convert/split_op.cc b/paddle/fluid/inference/tensorrt/convert/split_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..12179cccc76f8b0f595f41c135290dc0f3b50ad7
--- /dev/null
+++ b/paddle/fluid/inference/tensorrt/convert/split_op.cc
@@ -0,0 +1,75 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
+#include "paddle/fluid/inference/tensorrt/plugin/split_op_plugin.h"
+
+namespace paddle {
+namespace inference {
+namespace tensorrt {
+
+/*
+ * SplitOp.
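+ * Converts the fluid split op to a TensorRT plugin layer, since split has
+ * no native TensorRT layer to map onto.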
+ */
+class SplitOpConverter : public OpConverter {
+ public:
+  void operator()(const framework::proto::OpDesc& op,
+                  const framework::Scope& scope, bool test_mode) override {
+    VLOG(40) << "convert a fluid split op to tensorrt split layer";
+
+    framework::OpDesc op_desc(op, nullptr);
+    // Declare inputs
+    auto* input = engine_->GetITensor(op_desc.Input("X")[0]);
+    auto input_dims = input->getDimensions();
+    int input_num = op_desc.Input("X").size();
+    size_t output_num = op_desc.Output("Out").size();
+
+    // Get Attrs
+    PADDLE_ENFORCE(input_num == 1);
+    int axis = boost::get<int>(op_desc.GetAttr("axis"));
+    std::vector<int> output_lengths =
+        boost::get<std::vector<int>>(op_desc.GetAttr("sections"));
+    PADDLE_ENFORCE(axis != 0);
+    if (axis < 0) {
+      axis += input_dims.nbDims;
+    } else {
+      axis -= 1;
+    }
+
+    PADDLE_ENFORCE(output_lengths.size() == output_num);
+
+    //
+    SplitPlugin* plugin = new SplitPlugin(axis, output_lengths);
+    nvinfer1::IPluginLayer* layer =
+        engine_->AddPlugin(&input, input_num, plugin);
+
+    std::string layer_name = "split (Output: ";
+    for (size_t i = 0; i < output_num; i++) {
+      auto output_name = op_desc.Output("Out")[i];
+      layer->getOutput(i)->setName(output_name.c_str());
+      engine_->SetITensor(output_name, layer->getOutput(i));
+      layer_name += output_name;
+      if (test_mode) {
+        engine_->DeclareOutput(output_name);
+      }
+    }
+    layer->setName((layer_name + ")").c_str());
+  }
+};
+
+}  // namespace tensorrt
+}  // namespace inference
+}  // namespace paddle
+
+REGISTER_TRT_OP_CONVERTER(split, SplitOpConverter);
diff --git a/paddle/fluid/inference/tensorrt/convert/test_split_op.cc b/paddle/fluid/inference/tensorrt/convert/test_split_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..f81d011552c152c2df79e1a272f34b954ae2a3a1
--- /dev/null
+++ b/paddle/fluid/inference/tensorrt/convert/test_split_op.cc
@@ -0,0 +1,53 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+#include <gtest/gtest.h>
+#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
+#include "paddle/fluid/inference/tensorrt/convert/ut_helper.h"
+
+namespace paddle {
+namespace inference {
+namespace tensorrt {
+
+TEST(split_op, test) {
+  std::unordered_set<std::string> parameters({""});
+  framework::Scope scope;
+  TRTConvertValidation validator(10, parameters, scope, 1000);
+  validator.DeclInputVar("split_input", nvinfer1::DimsCHW(3, 2, 2));
+  validator.DeclOutputVar("split_out1", nvinfer1::DimsCHW(2, 2, 2));
+  validator.DeclOutputVar("split_out2", nvinfer1::DimsCHW(1, 2, 2));
+
+  // Prepare Op description
+  framework::OpDesc desc;
+  desc.SetType("split");
+  desc.SetInput("X", {"split_input"});
+  desc.SetOutput("Out", {"split_out1", "split_out2"});
+
+  int num = 0;
+  int axis = 1;
+  std::vector<int> output_lengths = {2, 1};
+  desc.SetAttr("axis", axis);
+  desc.SetAttr("num", num);
+  desc.SetAttr("sections", output_lengths);
+
+  validator.SetOp(*desc.Proto());
+
+  validator.Execute(1);
+}
+
+}  // namespace tensorrt
+}  // namespace inference
+}  // namespace paddle
+
+USE_OP(split);
diff --git a/paddle/fluid/inference/tensorrt/engine.cc b/paddle/fluid/inference/tensorrt/engine.cc
index 8adc3baca64845f596477a0abe61be31e7377d9f..fdd8b56b0ce5c9b5cb6395bcb437aae5ae27829b 100644
--- a/paddle/fluid/inference/tensorrt/engine.cc
+++ b/paddle/fluid/inference/tensorrt/engine.cc
@@ -255,6 +255,12 @@ void TensorRTEngine::freshDeviceId() {
   cudaSetDevice(device_);
 }
 
+nvinfer1::IPluginLayer *TensorRTEngine::AddPlugin(
+    nvinfer1::ITensor *const *inputs, int nbInputs, PluginTensorRT *plugin) {
+  owned_plugin_.emplace_back(plugin);
+  return infer_network_.get()->addPluginExt(inputs, nbInputs, *plugin);
+}
+
 }  // namespace tensorrt
 }  // namespace inference
 }  // namespace paddle
diff --git a/paddle/fluid/inference/tensorrt/engine.h b/paddle/fluid/inference/tensorrt/engine.h
index 828181200e300c370bbfa234c3c23ae44810878c..335acdf653e55cc7f3ceccdba88992851c8e0310 100644
--- a/paddle/fluid/inference/tensorrt/engine.h
+++ b/paddle/fluid/inference/tensorrt/engine.h
@@ -22,6 +22,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/tensor.h"
 #include "paddle/fluid/inference/engine.h"
 #include "paddle/fluid/inference/tensorrt/helper.h"
+#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
 #include "paddle/fluid/inference/utils/singleton.h"
 
 namespace paddle {
@@ -125,6 +126,8 @@ class TensorRTEngine : public EngineBase {
   void SetRuntimeBatch(size_t batch_size);
   int GetRuntimeBatch();
   int GetDevice() { return device_; }
+  nvinfer1::IPluginLayer* AddPlugin(nvinfer1::ITensor* const* inputs,
+                                    int nbInputs, PluginTensorRT*);
 
   // A pointer to CPU memory is needed of the TRT weight.
   // Before TRT runs, fluid loads weight into GPU storage.
@@ -164,8 +167,10 @@
   std::unordered_map<std::string, size_t> buffer_sizes_;
   std::unordered_map<std::string, nvinfer1::ITensor *> itensor_map_;
 
+  // The specific GPU id that the TensorRT engine is bound to.
   int device_;
+  std::vector<std::unique_ptr<PluginTensorRT>> owned_plugin_;
 
   // TensorRT related internal members
   template <typename T>
diff --git a/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt b/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt
new file mode 100644
index 0000000000000000000000000000000000000000..71b7a551619a43e5300ad3205418d1174c7019ff
--- /dev/null
+++ b/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt
@@ -0,0 +1 @@
+nv_library(tensorrt_plugin SRCS trt_plugin.cc split_op_plugin.cu DEPS enforce)
diff --git a/paddle/fluid/inference/tensorrt/plugin/serialize.h b/paddle/fluid/inference/tensorrt/plugin/serialize.h
new file mode 100644
index 0000000000000000000000000000000000000000..50c0b17d78327e22b0aa81fdac6958e80a30dfe8
--- /dev/null
+++ b/paddle/fluid/inference/tensorrt/plugin/serialize.h
@@ -0,0 +1,111 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <cassert>
+#include <cstring>
+#include <type_traits>
+#include <vector>
+
+template <typename T>
+inline void SerializeValue(void** buffer, T const& value);
+
+template <typename T>
+inline void DeserializeValue(void const** buffer, size_t* buffer_size,
+                             T* value);
+
+namespace {
+
+template <typename T, class Enable = void>
+struct Serializer {};
+
+template <typename T>
+struct Serializer<T, typename std::enable_if<std::is_arithmetic<T>::value ||
+                                             std::is_enum<T>::value ||
+                                             std::is_pod<T>::value>::type> {
+  static size_t SerializedSize(T const& value) { return sizeof(T); }
+  static void Serialize(void** buffer, T const& value) {
+    std::memcpy(*buffer, &value, sizeof(T));
+    reinterpret_cast<char*&>(*buffer) += sizeof(T);
+  }
+  static void Deserialize(void const** buffer, size_t* buffer_size, T* value) {
+    assert(*buffer_size >= sizeof(T));
+    std::memcpy(value, *buffer, sizeof(T));
+    reinterpret_cast<char const*&>(*buffer) += sizeof(T);
+    *buffer_size -= sizeof(T);
+  }
+};
+
+template <>
+struct Serializer<const char*> {
+  static size_t SerializedSize(const char* value) { return strlen(value) + 1; }
+  static void Serialize(void** buffer, const char* value) {
+    std::strcpy(static_cast<char*>(*buffer), value);
+    reinterpret_cast<char*&>(*buffer) += strlen(value) + 1;
+  }
+  static void Deserialize(void const** buffer, size_t* buffer_size,
+                          const char** value) {
+    *value = static_cast<char const*>(*buffer);
+    size_t data_size = strnlen(*value, *buffer_size) + 1;
+    assert(*buffer_size >= data_size);
+    reinterpret_cast<char const*&>(*buffer) += data_size;
+    *buffer_size -= data_size;
+  }
+};
+
+template <typename T>
+struct Serializer<std::vector<T>,
+                  typename std::enable_if<std::is_arithmetic<T>::value ||
+                                          std::is_enum<T>::value ||
+                                          std::is_pod<T>::value>::type> {
+  static size_t SerializedSize(std::vector<T> const& value) {
+    return sizeof(value.size()) + value.size() * sizeof(T);
+  }
+  static void Serialize(void** buffer, std::vector<T> const& value) {
+    SerializeValue(buffer, value.size());
+    size_t nbyte = value.size() * sizeof(T);
+    std::memcpy(*buffer, value.data(), nbyte);
+    reinterpret_cast<char*&>(*buffer) += nbyte;
+  }
+  static void Deserialize(void const** buffer, size_t* buffer_size,
+                          std::vector<T>* value) {
+    size_t size;
+    DeserializeValue(buffer, buffer_size, &size);
+    value->resize(size);
+    size_t nbyte = value->size() * sizeof(T);
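+    // The size header was consumed above; the remaining buffer must still
+    // hold the full vector payload.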
+    assert(*buffer_size >= nbyte);
+    std::memcpy(value->data(), *buffer, nbyte);
+    reinterpret_cast<char const*&>(*buffer) += nbyte;
+    *buffer_size -= nbyte;
+  }
+};
+
+}  // namespace
+
+template <typename T>
+inline size_t SerializedSize(T const& value) {
+  return Serializer<T>::SerializedSize(value);
+}
+
+template <typename T>
+inline void SerializeValue(void** buffer, T const& value) {
+  return Serializer<T>::Serialize(buffer, value);
+}
+
+template <typename T>
+inline void DeserializeValue(void const** buffer, size_t* buffer_size,
+                             T* value) {
+  return Serializer<T>::Deserialize(buffer, buffer_size, value);
+}
diff --git a/paddle/fluid/inference/tensorrt/plugin/split_op_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/split_op_plugin.cu
new file mode 100644
index 0000000000000000000000000000000000000000..bd6a44dcc14d50cddb879763a93abf4297494ec9
--- /dev/null
+++ b/paddle/fluid/inference/tensorrt/plugin/split_op_plugin.cu
@@ -0,0 +1,81 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <cassert>
+#include <vector>
+#include "paddle/fluid/inference/tensorrt/plugin/split_op_plugin.h"
+
+namespace paddle {
+namespace inference {
+namespace tensorrt {
+
+nvinfer1::Dims SplitPlugin::getOutputDimensions(int index,
+                                                const nvinfer1::Dims* inputDims,
+                                                int nbInputs) {
+  assert(nbInputs == 1);
+  assert(index < this->getNbOutputs());
+  nvinfer1::Dims const& input_dims = inputDims[0];
+  nvinfer1::Dims output_dims = input_dims;
+  output_dims.d[axis_] = output_length_.at(index);
+  return output_dims;
+}
+
+int SplitPlugin::initialize() {
+  std::vector<int> segment_offsets(1, 0);
+  for (int i = 0; i < this->getNbOutputs(); ++i) {
+    segment_offsets.push_back(segment_offsets.back() + output_length_[i]);
+  }
+  segment_offsets_ = segment_offsets;
+  nvinfer1::Dims dims = this->getInputDims(0);
+  nx_ = 1;
+  for (int i = dims.nbDims - 1; i > axis_; --i) {
+    nx_ *= dims.d[i];
+  }
+  ny_ = dims.d[axis_];
+  nz_ = 1;
+  for (int i = axis_ - 1; i >= 0; --i) {
+    nz_ *= dims.d[i];
+  }
+  return 0;
+}
+
+int SplitPlugin::enqueue(int batchSize, const void* const* inputs,
+                         void** outputs, void* workspace, cudaStream_t stream) {
+  auto const& input_dims = this->getInputDims(0);
+  float const* idata = reinterpret_cast<float const*>(inputs[0]);
+  float** odatas = reinterpret_cast<float**>(outputs);
+
+  // kernel impl here.
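+  // Per initialize(): each batch item is viewed as [nz_, ny_, nx_], with ny_
+  // being the split axis. Output i receives the ny-range
+  // [segment_offsets_[i], segment_offsets_[i + 1]) of every batch item.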
+  int inputBatchOffset = nx_ * ny_ * nz_;
+  for (int i = 0; i < this->getNbOutputs(); i++) {
+    for (int j = 0; j < batchSize; j++) {
+      cudaMemcpyAsync(
+          odatas[i] +
+              j * (segment_offsets_[i + 1] - segment_offsets_[i]) * nx_,
+          idata + inputBatchOffset * j + segment_offsets_[i] * nx_,
+          (segment_offsets_[i + 1] - segment_offsets_[i]) * nx_ * sizeof(float),
+          cudaMemcpyDeviceToDevice, stream);
+    }
+  }
+
+  return cudaGetLastError() != cudaSuccess;
+}
+
+}  // namespace tensorrt
+}  // namespace inference
+}  // namespace paddle
diff --git a/paddle/fluid/inference/tensorrt/plugin/split_op_plugin.h b/paddle/fluid/inference/tensorrt/plugin/split_op_plugin.h
new file mode 100644
index 0000000000000000000000000000000000000000..7281e40c331550de472df49c57b1d9a5226842d5
--- /dev/null
+++ b/paddle/fluid/inference/tensorrt/plugin/split_op_plugin.h
@@ -0,0 +1,74 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
+
+namespace paddle {
+namespace inference {
+namespace tensorrt {
+
+class SplitPlugin : public PluginTensorRT {
+  int axis_;
+  std::vector<int> output_length_;
+  int nx_, ny_, nz_;
+  std::vector<int> segment_offsets_;
+
+ protected:
+  virtual size_t getSerializationSize() override {
+    return SerializedSize(axis_) + SerializedSize(output_length_) +
+           getBaseSerializationSize();
+  }
+
+  // TRT will call this func when we need to serialize the configuration of
+  // tensorrt.
+  // It should not be called by users.
+  virtual void serialize(void *buffer) override {
+    serializeBase(buffer);
+    SerializeValue(&buffer, axis_);
+    SerializeValue(&buffer, output_length_);
+  }
+
+ public:
+  SplitPlugin(int axis, std::vector<int> const &output_lengths)
+      : axis_(axis), output_length_(output_lengths) {
+    assert(axis <= nvinfer1::Dims::MAX_DIMS);
+  }
+
+  // It was used for tensorrt deserialization.
+  // It should not be called by users.
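+  // The read order must mirror serialize(): base fields first, then axis_
+  // and the per-output section lengths.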
+  SplitPlugin(void const *serialData, size_t serialLength) {
+    deserializeBase(serialData, serialLength);
+    DeserializeValue(&serialData, &serialLength, &axis_);
+    DeserializeValue(&serialData, &serialLength, &output_length_);
+  }
+
+  SplitPlugin *clone() const override {
+    return new SplitPlugin(axis_, output_length_);
+  }
+
+  virtual const char *getPluginType() const override { return "split"; }
+  virtual int getNbOutputs() const override { return output_length_.size(); }
+  virtual nvinfer1::Dims getOutputDimensions(int index,
+                                             const nvinfer1::Dims *inputs,
+                                             int nbInputDims) override;
+  virtual int initialize() override;
+  virtual int enqueue(int batchSize, const void *const *inputs, void **outputs,
+                      void *workspace, cudaStream_t stream) override;
+};
+
+}  // namespace tensorrt
+}  // namespace inference
+}  // namespace paddle
diff --git a/paddle/fluid/inference/tensorrt/plugin/trt_plugin.cc b/paddle/fluid/inference/tensorrt/plugin/trt_plugin.cc
new file mode 100644
index 0000000000000000000000000000000000000000..08016d84b15bc750738f3183d8d61a5c90862288
--- /dev/null
+++ b/paddle/fluid/inference/tensorrt/plugin/trt_plugin.cc
@@ -0,0 +1,61 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
+
+namespace paddle {
+namespace inference {
+namespace tensorrt {
+
+void PluginTensorRT::serializeBase(void*& buffer) {
+  SerializeValue(&buffer, input_dims_);
+  SerializeValue(&buffer, max_batch_size_);
+  SerializeValue(&buffer, data_type_);
+  SerializeValue(&buffer, data_format_);
+}
+
+void PluginTensorRT::deserializeBase(void const*& serialData,
+                                     size_t& serialLength) {
+  DeserializeValue(&serialData, &serialLength, &input_dims_);
+  DeserializeValue(&serialData, &serialLength, &max_batch_size_);
+  DeserializeValue(&serialData, &serialLength, &data_type_);
+  DeserializeValue(&serialData, &serialLength, &data_format_);
+}
+
+size_t PluginTensorRT::getBaseSerializationSize() {
+  return (SerializedSize(input_dims_) + SerializedSize(max_batch_size_) +
+          SerializedSize(data_type_) + SerializedSize(data_format_));
+}
+
+bool PluginTensorRT::supportsFormat(nvinfer1::DataType type,
+                                    nvinfer1::PluginFormat format) const {
+  return ((type == nvinfer1::DataType::kFLOAT) &&
+          (format == nvinfer1::PluginFormat::kNCHW));
+}
+
+void PluginTensorRT::configureWithFormat(const nvinfer1::Dims* inputDims,
+                                         int nbInputs,
+                                         const nvinfer1::Dims* outputDims,
+                                         int nbOutputs, nvinfer1::DataType type,
+                                         nvinfer1::PluginFormat format,
+                                         int maxBatchSize) {
+  data_type_ = type;
+  data_format_ = format;
+  input_dims_.assign(inputDims, inputDims + nbInputs);
+  max_batch_size_ = maxBatchSize;
+}
+
+}  // namespace tensorrt
+}  // namespace inference
+}  // namespace paddle
diff --git a/paddle/fluid/inference/tensorrt/plugin/trt_plugin.h b/paddle/fluid/inference/tensorrt/plugin/trt_plugin.h
new file mode 100644
index 0000000000000000000000000000000000000000..4d85e955a49b7dcccae158ea06b76419419797cf
--- /dev/null
+++ b/paddle/fluid/inference/tensorrt/plugin/trt_plugin.h
@@ -0,0 +1,80 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <cassert>
+#include <cstring>
+#include <iostream>
+#include <unordered_map>
+#include <vector>
+#include "NvInfer.h"
+
+#include "paddle/fluid/inference/tensorrt/plugin/serialize.h"
+
+namespace paddle {
+namespace inference {
+namespace tensorrt {
+
+class PluginTensorRT : public nvinfer1::IPluginExt {
+ public:
+  PluginTensorRT() {}
+  PluginTensorRT(const void* serialized_data, size_t length) {}
+  nvinfer1::Dims const& getInputDims(int index) const {
+    return input_dims_.at(index);
+  }
+  size_t getMaxBatchSize() const { return max_batch_size_; }
+  nvinfer1::DataType getDataType() const { return data_type_; }
+  nvinfer1::PluginFormat getDataFormat() const { return data_format_; }
+  virtual const char* getPluginVersion() const { return "1"; }
+  size_t getWorkspaceSize(int) const override { return 0; }
+  void terminate() override {}
+  virtual ~PluginTensorRT() {}
+  // Check format support. The default is FLOAT32 and NCHW.
+  bool supportsFormat(nvinfer1::DataType type,
+                      nvinfer1::PluginFormat format) const override;
+  void configureWithFormat(const nvinfer1::Dims* inputDims, int nbInputs,
+                           const nvinfer1::Dims* outputDims, int nbOutputs,
+                           nvinfer1::DataType type,
+                           nvinfer1::PluginFormat format,
+                           int maxBatchSize) override;
+
+  // *NOTE* The following functions need to be overridden in the subclass.
+  virtual nvinfer1::IPluginExt* clone() const = 0;
+  virtual const char* getPluginType() const = 0;
+  // Initialize the layer for execution. This is called when the engine is
+  // created.
+  int initialize() override { return 0; }
+  // Serialize the layer config to buffer.
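+  // Subclasses are expected to call serializeBase()/getBaseSerializationSize()
+  // so the base fields below round-trip together with their own state.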
+  virtual void serialize(void* buffer) = 0;
+  virtual size_t getSerializationSize() = 0;
+  virtual int enqueue(int batchSize, const void* const* inputs, void** outputs,
+                      void* workspace, cudaStream_t stream) = 0;
+
+ protected:
+  // Deserialize input_dims, max_batch_size, data_type, data_format
+  void deserializeBase(void const*& serialData, size_t& serialLength);
+  size_t getBaseSerializationSize();
+  // Serialize input_dims, max_batch_size, data_type, data_format
+  void serializeBase(void*& buffer);
+
+  std::vector<nvinfer1::Dims> input_dims_;
+  size_t max_batch_size_;
+  nvinfer1::DataType data_type_;
+  nvinfer1::PluginFormat data_format_;
+};
+
+}  // namespace tensorrt
+}  // namespace inference
+}  // namespace paddle
diff --git a/paddle/fluid/inference/tests/api/CMakeLists.txt b/paddle/fluid/inference/tests/api/CMakeLists.txt
index fc3e44ffd741cee5185e01c254d0c591f3c179a2..74569057913d1db9a7184ab61ba655b3a66e49bd 100644
--- a/paddle/fluid/inference/tests/api/CMakeLists.txt
+++ b/paddle/fluid/inference/tests/api/CMakeLists.txt
@@ -45,11 +45,7 @@ inference_analysis_api_test(test_analyzer_rnn2 ${RNN2_INSTALL_DIR} analyzer_rnn2
 # DAM
 set(DAM_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/dam")
 download_model_and_data(${DAM_INSTALL_DIR} "DAM_model.tar.gz" "DAM_data.txt.tar.gz")
-inference_analysis_test(test_analyzer_dam SRCS analyzer_dam_tester.cc
-  EXTRA_DEPS ${INFERENCE_EXTRA_DEPS} ARGS
-  --infer_model=${DAM_INSTALL_DIR}/model
-  --infer_data=${DAM_INSTALL_DIR}/data.txt
-  --use_analysis=0)
+inference_analysis_api_test(test_analyzer_dam ${DAM_INSTALL_DIR} analyzer_dam_tester.cc)
 
 # chinese_ner
 set(CHINESE_NER_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/chinese_ner")
@@ -108,8 +104,7 @@ if(WITH_GPU AND TENSORRT_FOUND)
   if (NOT EXISTS ${TRT_MODEL_INSTALL_DIR})
     inference_download_and_uncompress(${TRT_MODEL_INSTALL_DIR} ${INFERENCE_URL}/tensorrt_test "trt_test_models.tar.gz")
   endif()
-  inference_analysis_test(test_trt_models SRCS trt_models_tester.cc
-    EXTRA_DEPS ${INFERENCE_EXTRA_DEPS} analysis ${analysis_deps} ir_pass_manager analysis_predictor
-    ARGS --dirname=${TRT_MODEL_INSTALL_DIR}/trt_test_models SERIAL)
+  inference_analysis_test(test_trt_models SRCS trt_models_tester.cc
+    EXTRA_DEPS ${INFERENCE_EXTRA_DEPS} analysis ${analysis_deps} ir_pass_manager analysis_predictor
+    ARGS --infer_model=${TRT_MODEL_INSTALL_DIR}/trt_test_models SERIAL)
 endif()
diff --git a/paddle/fluid/inference/tests/api/analyzer_dam_tester.cc b/paddle/fluid/inference/tests/api/analyzer_dam_tester.cc
index ceac5dc7e14365c77cf1cbbbc16e4bf3ebfced73..b369cba5c8b3f8aadd1123d6b7345fad6e47bd0f 100644
--- a/paddle/fluid/inference/tests/api/analyzer_dam_tester.cc
+++ b/paddle/fluid/inference/tests/api/analyzer_dam_tester.cc
@@ -69,7 +69,7 @@ struct DataRecord {
       num_lines++;
       std::vector<std::string> data;
       split(line, ',', &data);
-      CHECK_EQ(data.size(), 2 * MAX_TURN_NUM + 3);
+      CHECK_EQ(data.size(), (size_t)(2 * MAX_TURN_NUM + 3));
       // load turn data
       std::vector<std::string> turns_tmp[MAX_TURN_NUM];
       for (int i = 0; i < MAX_TURN_NUM; ++i) {
@@ -178,7 +178,8 @@ TEST(Analyzer_dam, profile) {
   std::vector<PaddleTensor> outputs;
   std::vector<std::vector<PaddleTensor>> input_slots_all;
   SetInput(&input_slots_all);
-  TestPrediction(cfg, input_slots_all, &outputs, FLAGS_num_threads);
+  TestPrediction(reinterpret_cast<const PaddlePredictor::Config *>(&cfg),
+                 input_slots_all, &outputs, FLAGS_num_threads);
 
   if (FLAGS_num_threads == 1 && !FLAGS_test_all_data) {
     PADDLE_ENFORCE_GT(outputs.size(), 0);
@@ -196,15 +197,13 @@ TEST(Analyzer_dam, fuse_statis) {
   contrib::AnalysisConfig cfg;
   SetConfig(&cfg);
 
-  if (FLAGS_use_analysis) {
-    int num_ops;
-    auto predictor = CreatePaddlePredictor<contrib::AnalysisConfig>(cfg);
-    auto fuse_statis = GetFuseStatis(
-        static_cast<AnalysisPredictor *>(predictor.get()), &num_ops);
-    ASSERT_TRUE(fuse_statis.count("fc_fuse"));
-    EXPECT_EQ(fuse_statis.at("fc_fuse"), 317);
-    EXPECT_EQ(num_ops, 2020);
-  }
+  int num_ops;
+  auto predictor = CreatePaddlePredictor<contrib::AnalysisConfig>(cfg);
+  auto fuse_statis = GetFuseStatis(
+      static_cast<AnalysisPredictor *>(predictor.get()), &num_ops);
+  ASSERT_TRUE(fuse_statis.count("fc_fuse"));
+  EXPECT_EQ(fuse_statis.at("fc_fuse"), 317);
+  EXPECT_EQ(num_ops, 2020);
 }
 
 // Compare result of NativeConfig and AnalysisConfig
@@ -215,9 +214,8 @@ TEST(Analyzer_dam, compare) {
   std::vector<std::vector<PaddleTensor>> input_slots_all;
   SetInput(&input_slots_all);
 
-  if (FLAGS_use_analysis) {
-    CompareNativeAndAnalysis(cfg, input_slots_all);
-  }
+  CompareNativeAndAnalysis(
+      reinterpret_cast<const PaddlePredictor::Config *>(&cfg), input_slots_all);
 }
 
 }  // namespace inference
diff --git a/paddle/fluid/inference/tests/api/analyzer_lac_tester.cc b/paddle/fluid/inference/tests/api/analyzer_lac_tester.cc
index 5fb551810fd4d1c56547a8aa581cb6c4587df031..310852e2f7cb284bda3041911d0059b55ee3b477 100644
--- a/paddle/fluid/inference/tests/api/analyzer_lac_tester.cc
+++ b/paddle/fluid/inference/tests/api/analyzer_lac_tester.cc
@@ -133,7 +133,8 @@ TEST(Analyzer_LAC, profile) {
   std::vector<std::vector<PaddleTensor>> input_slots_all;
   SetInput(&input_slots_all);
-  TestPrediction(cfg, input_slots_all, &outputs, FLAGS_num_threads);
+  TestPrediction(reinterpret_cast<const PaddlePredictor::Config *>(&cfg),
+                 input_slots_all, &outputs, FLAGS_num_threads);
 
   if (FLAGS_num_threads == 1 && !FLAGS_test_all_data) {
     // the first inference result
@@ -175,7 +176,8 @@ TEST(Analyzer_LAC, compare) {
   std::vector<std::vector<PaddleTensor>> input_slots_all;
   SetInput(&input_slots_all);
-  CompareNativeAndAnalysis(cfg, input_slots_all);
+  CompareNativeAndAnalysis(
+      reinterpret_cast<const PaddlePredictor::Config *>(&cfg), input_slots_all);
 }
 
 }  // namespace analysis
diff --git a/paddle/fluid/inference/tests/api/analyzer_ner_tester.cc b/paddle/fluid/inference/tests/api/analyzer_ner_tester.cc
index d91f7c314d0a936da6f5b0c41920c905af5cd0ee..3a5f844de3cae7eb9b6e3555c5219c6cf8ee1919 100644
--- a/paddle/fluid/inference/tests/api/analyzer_ner_tester.cc
+++ b/paddle/fluid/inference/tests/api/analyzer_ner_tester.cc
@@ -121,7 +121,8 @@ TEST(Analyzer_Chinese_ner, profile) {
   std::vector<std::vector<PaddleTensor>> input_slots_all;
   SetInput(&input_slots_all);
-  TestPrediction(cfg, input_slots_all, &outputs, FLAGS_num_threads);
+  TestPrediction(reinterpret_cast<const PaddlePredictor::Config *>(&cfg),
+                 input_slots_all, &outputs, FLAGS_num_threads);
 
   if (FLAGS_num_threads == 1 && !FLAGS_test_all_data) {
     // the first inference result
@@ -160,7 +161,8 @@ TEST(Analyzer_Chinese_ner, compare) {
   std::vector<std::vector<PaddleTensor>> input_slots_all;
   SetInput(&input_slots_all);
-  CompareNativeAndAnalysis(cfg, input_slots_all);
+  CompareNativeAndAnalysis(
+      reinterpret_cast<const PaddlePredictor::Config *>(&cfg), input_slots_all);
 }
 
 }  // namespace inference
diff --git a/paddle/fluid/inference/tests/api/analyzer_resnet50_tester.cc b/paddle/fluid/inference/tests/api/analyzer_resnet50_tester.cc
index 5c92096d9d3e607d79ca74e16a558a4999c44414..2b936175ed3f8ec24826485027048c82df0461ab 100644
--- a/paddle/fluid/inference/tests/api/analyzer_resnet50_tester.cc
+++ b/paddle/fluid/inference/tests/api/analyzer_resnet50_tester.cc
@@ -45,7 +45,8 @@ void profile(bool use_mkldnn = false) {
   std::vector<std::vector<PaddleTensor>> input_slots_all;
   SetInput(&input_slots_all);
-  TestPrediction(cfg, input_slots_all, &outputs, FLAGS_num_threads);
+  TestPrediction(reinterpret_cast<const PaddlePredictor::Config *>(&cfg),
+                 input_slots_all, &outputs, FLAGS_num_threads);
 }
 
 TEST(Analyzer_resnet50, profile) { profile(); }
@@ -74,7 +75,8 @@ void compare(bool use_mkldnn = false) {
   std::vector<std::vector<PaddleTensor>> input_slots_all;
   SetInput(&input_slots_all);
-  CompareNativeAndAnalysis(cfg, input_slots_all);
+  CompareNativeAndAnalysis(
+      reinterpret_cast<const PaddlePredictor::Config *>(&cfg), input_slots_all);
 }
 TEST(Analyzer_resnet50, compare) { compare(); }
diff --git a/paddle/fluid/inference/tests/api/analyzer_rnn1_tester.cc b/paddle/fluid/inference/tests/api/analyzer_rnn1_tester.cc
index 612ae121b2ecbccb0ba8faf72aef83ec01a104bd..1ae2b4b03a1b2a66b3ddc8cb66d9575751a52297 100644
--- a/paddle/fluid/inference/tests/api/analyzer_rnn1_tester.cc
+++ b/paddle/fluid/inference/tests/api/analyzer_rnn1_tester.cc
@@ -233,8 +233,8 @@ TEST(Analyzer_rnn1, profile) {
   std::vector<std::vector<PaddleTensor>> input_slots_all;
   SetInput(&input_slots_all);
 
-  LOG(INFO) << "to test prediction";
-  TestPrediction(cfg, input_slots_all, &outputs, FLAGS_num_threads);
+  TestPrediction(reinterpret_cast<const PaddlePredictor::Config *>(&cfg),
+                 input_slots_all, &outputs, FLAGS_num_threads);
 }
 
 // Check the fuse status
@@ -261,7 +261,8 @@ TEST(Analyzer_rnn1, compare) {
   std::vector<std::vector<PaddleTensor>> input_slots_all;
   SetInput(&input_slots_all);
-  CompareNativeAndAnalysis(cfg, input_slots_all);
+  CompareNativeAndAnalysis(
+      reinterpret_cast<const PaddlePredictor::Config *>(&cfg), input_slots_all);
 }
 
 // Test Multi-Thread.
@@ -272,7 +273,8 @@ TEST(Analyzer_rnn1, multi_thread) {
   std::vector<std::vector<PaddleTensor>> input_slots_all;
   SetInput(&input_slots_all);
-  TestPrediction(cfg, input_slots_all, &outputs, 4 /* multi_thread */);
+  TestPrediction(reinterpret_cast<const PaddlePredictor::Config *>(&cfg),
+                 input_slots_all, &outputs, 4 /* multi_thread */);
 }
 
 // Validate that the AnalysisPredictor + ZeroCopyTensor really works by testing
diff --git a/paddle/fluid/inference/tests/api/analyzer_rnn2_tester.cc b/paddle/fluid/inference/tests/api/analyzer_rnn2_tester.cc
index e0eb919bd896d73a557001982a436fc93f087a74..e2985006f0ed858e778bf4737be3aaee0e056021 100644
--- a/paddle/fluid/inference/tests/api/analyzer_rnn2_tester.cc
+++ b/paddle/fluid/inference/tests/api/analyzer_rnn2_tester.cc
@@ -132,7 +132,8 @@ TEST(Analyzer_rnn2, profile) {
   std::vector<std::vector<PaddleTensor>> input_slots_all;
   SetInput(&input_slots_all);
-  TestPrediction(cfg, input_slots_all, &outputs, FLAGS_num_threads);
+  TestPrediction(reinterpret_cast<const PaddlePredictor::Config *>(&cfg),
+                 input_slots_all, &outputs, FLAGS_num_threads);
 
   if (FLAGS_num_threads == 1 && !FLAGS_test_all_data) {
     // the first inference result
@@ -153,7 +154,8 @@ TEST(Analyzer_rnn2, compare) {
   std::vector<std::vector<PaddleTensor>> input_slots_all;
   SetInput(&input_slots_all);
-  CompareNativeAndAnalysis(cfg, input_slots_all);
+  CompareNativeAndAnalysis(
+      reinterpret_cast<const PaddlePredictor::Config *>(&cfg), input_slots_all);
 }
 
 }  // namespace inference
diff --git a/paddle/fluid/inference/tests/api/analyzer_seq_conv1_tester.cc b/paddle/fluid/inference/tests/api/analyzer_seq_conv1_tester.cc
index f590ef27967e47ffcb3a97e80dd147efdd1906e6..858191184a377a26042c98e17d5b8df782575efc 100644
--- a/paddle/fluid/inference/tests/api/analyzer_seq_conv1_tester.cc
+++ b/paddle/fluid/inference/tests/api/analyzer_seq_conv1_tester.cc
@@ -161,7 +161,8 @@ TEST(Analyzer_seq_conv1, profile) {
   std::vector<std::vector<PaddleTensor>> input_slots_all;
   SetInput(&input_slots_all);
-  TestPrediction(cfg, input_slots_all, &outputs, FLAGS_num_threads);
+  TestPrediction(reinterpret_cast<const PaddlePredictor::Config *>(&cfg),
+                 input_slots_all, &outputs, FLAGS_num_threads);
 
   if (FLAGS_num_threads == 1 && !FLAGS_test_all_data) {
     // the first inference result
@@ -199,7 +200,8 @@ TEST(Analyzer_seq_conv1, compare) {
   std::vector<std::vector<PaddleTensor>> input_slots_all;
   SetInput(&input_slots_all);
-  CompareNativeAndAnalysis(cfg, input_slots_all);
+  CompareNativeAndAnalysis(
+      reinterpret_cast<const PaddlePredictor::Config *>(&cfg), input_slots_all);
 }
 
 }  // namespace inference
diff --git a/paddle/fluid/inference/tests/api/analyzer_text_classification_tester.cc b/paddle/fluid/inference/tests/api/analyzer_text_classification_tester.cc
index 05bffede472d9674aa4213468662d7e08792035b..34a241f070fdc62d1b1e94835fb1dad405baafa9 100644
--- a/paddle/fluid/inference/tests/api/analyzer_text_classification_tester.cc
+++ b/paddle/fluid/inference/tests/api/analyzer_text_classification_tester.cc
@@ -74,7 +74,8 @@ TEST(Analyzer_Text_Classification, profile) {
   std::vector<std::vector<PaddleTensor>> input_slots_all;
   SetInput(&input_slots_all);
-  TestPrediction(cfg, input_slots_all, &outputs, FLAGS_num_threads);
+  TestPrediction(reinterpret_cast<const PaddlePredictor::Config *>(&cfg),
+                 input_slots_all, &outputs, FLAGS_num_threads);
 
   if (FLAGS_num_threads == 1) {
     // Get output
@@ -101,7 +102,8 @@ TEST(Analyzer_Text_Classification, compare) {
   std::vector<std::vector<PaddleTensor>> input_slots_all;
   SetInput(&input_slots_all);
-  CompareNativeAndAnalysis(cfg, input_slots_all);
+  CompareNativeAndAnalysis(
+      reinterpret_cast<const PaddlePredictor::Config *>(&cfg), input_slots_all);
 }
 
 TEST(Analyzer_Text_Classification, compare_against_embedding_fc_lstm_fused) {
@@ -112,7 +114,8 @@ TEST(Analyzer_Text_Classification, compare_against_embedding_fc_lstm_fused) {
   std::vector<std::vector<PaddleTensor>> input_slots_all;
   SetInput(&input_slots_all);
-  CompareNativeAndAnalysis(cfg, input_slots_all);
+  CompareNativeAndAnalysis(
+      reinterpret_cast<const PaddlePredictor::Config *>(&cfg), input_slots_all);
 }
 
 }  // namespace inference
diff --git a/paddle/fluid/inference/tests/api/analyzer_vis_tester.cc b/paddle/fluid/inference/tests/api/analyzer_vis_tester.cc
index 8fafd25b781a1755cce3d882e92b7ed018d3686c..956a235edcefb7d688983c3b63b187e284efb02a 100644
--- a/paddle/fluid/inference/tests/api/analyzer_vis_tester.cc
+++ b/paddle/fluid/inference/tests/api/analyzer_vis_tester.cc
@@ -59,9 +59,6 @@ void SetConfig(AnalysisConfig *cfg) {
   cfg->specify_input_name = true;
   // TODO(TJ): fix fusion gru
   cfg->pass_builder()->DeletePass("fc_gru_fuse_pass");
-#ifdef PADDLE_WITH_MKLDNN
-  cfg->EnableMKLDNN();
-#endif
 }
 
 void SetInput(std::vector<std::vector<PaddleTensor>> *inputs) {
@@ -94,7 +91,8 @@ void profile(bool use_mkldnn = false) {
   std::vector<std::vector<PaddleTensor>> input_slots_all;
   SetInput(&input_slots_all);
-  TestPrediction(cfg, input_slots_all, &outputs, FLAGS_num_threads);
+  TestPrediction(reinterpret_cast<const PaddlePredictor::Config *>(&cfg),
+                 input_slots_all, &outputs, FLAGS_num_threads);
 
   if (FLAGS_num_threads == 1 && !FLAGS_test_all_data) {
     const float ocr_result_data[] = {
@@ -136,7 +134,8 @@ void compare(bool use_mkldnn = false) {
   std::vector<std::vector<PaddleTensor>> input_slots_all;
   SetInput(&input_slots_all);
-  CompareNativeAndAnalysis(cfg, input_slots_all);
+  CompareNativeAndAnalysis(
+      reinterpret_cast<const PaddlePredictor::Config *>(&cfg), input_slots_all);
 }
 
 TEST(Analyzer_vis, compare) { compare(); }
diff --git a/paddle/fluid/inference/tests/api/config_printer.h b/paddle/fluid/inference/tests/api/config_printer.h
new file mode 100644
index 0000000000000000000000000000000000000000..aa0c4b1d049bc276cda2f58ac1edd8102fb3fd88
--- /dev/null
+++ b/paddle/fluid/inference/tests/api/config_printer.h
@@ -0,0 +1,79 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+#pragma once
+
+#include <sstream>
+#include <string>
+#include "paddle/fluid/inference/api/paddle_inference_api.h"
+
+namespace paddle {
+namespace inference {
+
+thread_local int num_spaces = 0;
+
+static std::string GenSpaces(int num_spaces) {
+  std::ostringstream os;
+  for (int i = 0; i < num_spaces; ++i) {
+    os << "  ";
+  }
+  return os.str();
+}
+
+std::ostream &operator<<(std::ostream &os,
+                         const PaddlePredictor::Config &config) {
+  os << GenSpaces(num_spaces) << "PaddlePredictor::Config {\n";
+  num_spaces++;
+  os << GenSpaces(num_spaces) << "model_dir: " << config.model_dir << "\n";
+  num_spaces--;
+  os << GenSpaces(num_spaces) << "}\n";
+  return os;
+}
+
+std::ostream &operator<<(std::ostream &os, const NativeConfig &config) {
+  os << GenSpaces(num_spaces) << "NativeConfig {\n";
+  num_spaces++;
+  os << *reinterpret_cast<const PaddlePredictor::Config *>(&config);
+  os << GenSpaces(num_spaces) << "use_gpu: " << config.use_gpu << "\n";
+  os << GenSpaces(num_spaces) << "device: " << config.device << "\n";
+  os << GenSpaces(num_spaces)
+     << "fraction_of_gpu_memory: " << config.fraction_of_gpu_memory << "\n";
+  os << GenSpaces(num_spaces) << "prog_file: " << config.prog_file << "\n";
+  os << GenSpaces(num_spaces) << "param_file: " << config.param_file << "\n";
+  os << GenSpaces(num_spaces)
+     << "specify_input_name: " << config.specify_input_name << "\n";
+  num_spaces--;
+  os << GenSpaces(num_spaces) << "}\n";
+  return os;
+}
+
+std::ostream &operator<<(std::ostream &os,
+                         const contrib::AnalysisConfig &config) {
+  os << GenSpaces(num_spaces) << "contrib::AnalysisConfig {\n";
+  num_spaces++;
+  os << *reinterpret_cast<const NativeConfig *>(&config);
+  os << GenSpaces(num_spaces) << "enable_ir_optim: " << config.enable_ir_optim
+     << "\n";
+  os << GenSpaces(num_spaces)
+     << "use_feed_fetch_ops: " << config.use_feed_fetch_ops << "\n";
+  os << GenSpaces(num_spaces) << "use_tensorrt: " << config.use_tensorrt()
+     << "\n";
+  os << GenSpaces(num_spaces) << "use_mkldnn: " << config.use_mkldnn() << "\n";
+  num_spaces--;
+  os << GenSpaces(num_spaces) << "}\n";
+  return os;
+}
+
+}  // namespace inference
+}  // namespace paddle
diff --git a/paddle/fluid/inference/tests/api/tester_helper.h b/paddle/fluid/inference/tests/api/tester_helper.h
index ab4ab20b58020e45f5002d4436d621004e4326fa..a4046914132cc713a707fc2a4d12087383d77fe5 100644
--- a/paddle/fluid/inference/tests/api/tester_helper.h
+++ b/paddle/fluid/inference/tests/api/tester_helper.h
@@ -19,13 +19,16 @@
 #include <gtest/gtest.h>
 #include <thread>  // NOLINT
 #include <vector>
+
 #include "paddle/fluid/framework/ir/fuse_pass_base.h"
 #include "paddle/fluid/framework/scope.h"
 #include "paddle/fluid/inference/analysis/analyzer.h"
 #include "paddle/fluid/inference/analysis/ut_helper.h"
 #include "paddle/fluid/inference/api/analysis_predictor.h"
-#include "paddle/fluid/inference/api/helper.h"
 #include "paddle/fluid/inference/api/paddle_inference_pass.h"
+
+#include "paddle/fluid/inference/api/helper.h"
+#include "paddle/fluid/inference/tests/api/config_printer.h"
 #include "paddle/fluid/inference/tests/test_helper.h"
 #include "paddle/fluid/platform/profiler.h"
 
@@ -38,10 +41,18 @@ DEFINE_int32(num_threads, 1, "Running the inference program in multi-threads.");
 DEFINE_bool(use_analysis, true,
             "Running the inference program in analysis mode.");
 
+DECLARE_bool(profile);
+
 namespace paddle {
 namespace inference {
 
-using contrib::AnalysisConfig;
+void PrintConfig(const PaddlePredictor::Config *config, bool use_analysis) {
+  if (use_analysis) {
+    LOG(INFO) << *reinterpret_cast<const contrib::AnalysisConfig *>(config);
+    return;
+  }
+  LOG(INFO) << *config;
+}
 
 void CompareResult(const std::vector<PaddleTensor> &outputs,
                   const std::vector<PaddleTensor> &ref_outputs) {
@@ -77,12 +88,13 @@ void CompareResult(const std::vector<PaddleTensor> &outputs,
 }
 
 std::unique_ptr<PaddlePredictor> CreateTestPredictor(
-    const AnalysisConfig &config, bool use_analysis = true) {
+    const PaddlePredictor::Config *config, bool use_analysis = true) {
   if (use_analysis) {
-    return CreatePaddlePredictor<contrib::AnalysisConfig>(config);
-  } else {
-    return CreatePaddlePredictor<NativeConfig>(config);
+    return CreatePaddlePredictor<contrib::AnalysisConfig>(
+        *(reinterpret_cast<const contrib::AnalysisConfig *>(config)));
   }
+  return CreatePaddlePredictor<NativeConfig>(
+      *(reinterpret_cast<const NativeConfig *>(config)));
 }
 
 size_t GetSize(const PaddleTensor &out) { return VecReduceToInt(out.shape); }
 
@@ -111,11 +123,23 @@ std::unordered_map<std::string, int> GetFuseStatis(PaddlePredictor *predictor,
 }
 
 void SetFakeImageInput(std::vector<std::vector<PaddleTensor>> *inputs,
-                       const std::string &dirname) {
+                       const std::string &dirname, bool is_combined = true,
+                       std::string model_filename = "model",
+                       std::string params_filename = "params") {
   // Set fake_image_data
   PADDLE_ENFORCE_EQ(FLAGS_test_all_data, 0, "Only have single batch of data.");
-  std::vector<std::vector<int64_t>> feed_target_shapes =
-      GetFeedTargetShapes(dirname, true, "model", "params");
+  std::vector<std::vector<int64_t>> feed_target_shapes = GetFeedTargetShapes(
+      dirname, is_combined, model_filename, params_filename);
+  std::ostringstream os;
+  for (size_t i = 0; i < feed_target_shapes.size(); ++i) {
+    os << "feed target " << i << ": {" << feed_target_shapes[i][0];
+    for (size_t j = 1; j < feed_target_shapes[i].size(); ++j) {
+      os << ", " << feed_target_shapes[i][j];
+    }
+    os << "}\n";
+  }
+  LOG(INFO) << os.str();
+
   int dim1 = feed_target_shapes[0][1];
   int dim2 = feed_target_shapes[0][2];
   int dim3 = feed_target_shapes[0][3];
@@ -139,25 +163,43 @@ void SetFakeImageInput(std::vector<std::vector<PaddleTensor>> *inputs,
 }
 
 void TestOneThreadPrediction(
-    const AnalysisConfig &config,
+    const PaddlePredictor::Config *config,
     const std::vector<std::vector<PaddleTensor>> &inputs,
     std::vector<PaddleTensor> *outputs, bool use_analysis = true) {
   int batch_size = FLAGS_batch_size;
   int num_times = FLAGS_repeat;
   auto predictor = CreateTestPredictor(config, use_analysis);
-  Timer timer;
-  timer.tic();
-  for (int i = 0; i < num_times; i++) {
-    for (size_t j = 0; j < inputs.size(); j++) {
-      predictor->Run(inputs[j], outputs);
+
+  // warmup run
+  LOG(INFO) << "Warm up run...";
+  {
+    Timer warmup_timer;
+    warmup_timer.tic();
+    predictor->Run(inputs[0], outputs, batch_size);
+    PrintTime(batch_size, 1, 1, 0, warmup_timer.toc(), 1);
+#if !defined(_WIN32)
+    if (FLAGS_profile) {
+      paddle::platform::ResetProfiler();
+    }
+#endif
+  }
+
+  LOG(INFO) << "Run " << num_times << " times...";
+  {
+    Timer run_timer;
+    run_timer.tic();
+    for (int i = 0; i < num_times; i++) {
+      for (size_t j = 0; j < inputs.size(); j++) {
+        predictor->Run(inputs[j], outputs, batch_size);
+      }
     }
+    PrintTime(batch_size, num_times, 1, 0, run_timer.toc() / num_times,
+              inputs.size());
   }
-  PrintTime(batch_size, num_times, 1, 0, timer.toc() / num_times,
-            inputs.size());
 }
 
 void TestMultiThreadPrediction(
-    const AnalysisConfig &config,
+    const PaddlePredictor::Config *config,
     const std::vector<std::vector<PaddleTensor>> &inputs,
     std::vector<PaddleTensor> *outputs, int num_threads,
     bool use_analysis = true) {
@@ -200,12 +242,11 @@ void TestMultiThreadPrediction(
   }
 }
 
-void TestPrediction(const AnalysisConfig &config,
+void TestPrediction(const PaddlePredictor::Config *config,
                    const std::vector<std::vector<PaddleTensor>> &inputs,
                    std::vector<PaddleTensor> *outputs, int num_threads,
                    bool use_analysis = FLAGS_use_analysis) {
-  LOG(INFO) << "use_analysis: " << use_analysis
-            << ", use_mkldnn: " << config.use_mkldnn();
+  PrintConfig(config, use_analysis);
   if (num_threads == 1) {
     TestOneThreadPrediction(config, inputs, outputs,
                            use_analysis);
   } else {
@@ -215,9 +256,9 @@
 }
 
 void CompareNativeAndAnalysis(
-    const AnalysisConfig &config,
+    const PaddlePredictor::Config *config,
    const std::vector<std::vector<PaddleTensor>> &inputs) {
-  LOG(INFO) << "use_mkldnn: " << config.use_mkldnn();
+  PrintConfig(config, true);
   std::vector<PaddleTensor> native_outputs, analysis_outputs;
   TestOneThreadPrediction(config, inputs, &native_outputs, false);
   TestOneThreadPrediction(config, inputs, &analysis_outputs, true);
diff --git a/paddle/fluid/inference/tests/api/trt_models_tester.cc b/paddle/fluid/inference/tests/api/trt_models_tester.cc
index 71423154f84797cf564dd4e71941853fae5a0767..922feba10fec5d1d13b47dbce064fce2e01d8998 100644
--- a/paddle/fluid/inference/tests/api/trt_models_tester.cc
+++ b/paddle/fluid/inference/tests/api/trt_models_tester.cc
@@ -1,148 +1,149 @@
-// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
 
 #include <gflags/gflags.h>
 #include <glog/logging.h>
 #include <gtest/gtest.h>
-#include "paddle/fluid/inference/analysis/analyzer.h"
-#include "paddle/fluid/inference/api/helper.h"
-#include "paddle/fluid/inference/api/paddle_inference_api.h"
-#include "paddle/fluid/inference/api/paddle_inference_pass.h"
+
 #include "paddle/fluid/inference/tests/api/tester_helper.h"
 
 namespace paddle {
-using paddle::contrib::AnalysisConfig;
-
-DEFINE_string(dirname, "", "Directory of the inference model.");
-
-NativeConfig GetConfigNative() {
-  NativeConfig config;
-  config.model_dir = FLAGS_dirname;
-  // LOG(INFO) << "dirname  " << config.model_dir;
-  config.fraction_of_gpu_memory = 0.15;
-  config.use_gpu = true;
-  config.device = 0;
-  return config;
-}
-
-void PrepareTRTConfig(AnalysisConfig *config) {
-  config->model_dir = FLAGS_dirname + "/" + "mobilenet";
-  config->fraction_of_gpu_memory = 0.15;
-  config->EnableTensorRtEngine(1 << 10, 5);
-  config->pass_builder()->DeletePass("conv_bn_fuse_pass");
-  config->pass_builder()->DeletePass("fc_fuse_pass");
-  config->pass_builder()->TurnOnDebug();
+namespace inference {
+
+DEFINE_bool(use_tensorrt, true, "Test the performance of TensorRT engine.");
+DEFINE_string(prog_filename, "", "Name of model file.");
+DEFINE_string(param_filename, "", "Name of parameters file.");
+
+template <typename ConfigType>
+void SetConfig(ConfigType* config, std::string model_dir, bool use_gpu,
+               bool use_tensorrt = false, int batch_size = -1) {
+  if (!FLAGS_prog_filename.empty() && !FLAGS_param_filename.empty()) {
+    config->prog_file = model_dir + "/" + FLAGS_prog_filename;
+    config->param_file = model_dir + "/" + FLAGS_param_filename;
+  } else {
+    config->model_dir = model_dir;
+  }
+  if (use_gpu) {
+    config->use_gpu = true;
+    config->device = 0;
+    config->fraction_of_gpu_memory = 0.15;
+  }
 }
 
-void PrepareInputs(std::vector<PaddleTensor> *tensors, int batch_size) {
-  PADDLE_ENFORCE_EQ(tensors->size(), 1UL);
-  auto &tensor = tensors->front();
-  int height = 224;
-  int width = 224;
-  float *data = new float[batch_size * 3 * height * width];
-  memset(data, 0, sizeof(float) * (batch_size * 3 * height * width));
-  data[0] = 1.0f;
-
-  // Prepare inputs
-  tensor.name = "input_0";
-  tensor.shape = std::vector<int>({batch_size, 3, height, width});
-  tensor.data = PaddleBuf(static_cast<void *>(data),
-                          sizeof(float) * (batch_size * 3 * height * width));
-  tensor.dtype = PaddleDType::FLOAT32;
+template <>
+void SetConfig<contrib::AnalysisConfig>(contrib::AnalysisConfig* config,
+                                        std::string model_dir, bool use_gpu,
+                                        bool use_tensorrt, int batch_size) {
+  if (!FLAGS_prog_filename.empty() && !FLAGS_param_filename.empty()) {
+    config->prog_file = model_dir + "/" + FLAGS_prog_filename;
+    config->param_file = model_dir + "/" + FLAGS_param_filename;
+  } else {
+    config->model_dir = model_dir;
+  }
+  if (use_gpu) {
+    config->use_gpu = true;
+    config->device = 0;
+    config->fraction_of_gpu_memory = 0.15;
+    if (use_tensorrt) {
+      config->EnableTensorRtEngine(1 << 10, batch_size);
+      config->pass_builder()->DeletePass("conv_bn_fuse_pass");
+      config->pass_builder()->DeletePass("fc_fuse_pass");
+      config->pass_builder()->TurnOnDebug();
+    } else {
+      config->enable_ir_optim = true;
+    }
+  }
 }
 
-void CompareTensorRTWithFluid(int batch_size, std::string model_dirname) {
-  auto config0 = GetConfigNative();
-  config0.model_dir = model_dirname;
-
-  AnalysisConfig config1(true);
-  PrepareTRTConfig(&config1);
-  config1.model_dir = model_dirname;
-
-  auto predictor0 = CreatePaddlePredictor<NativeConfig>(config0);
-  auto predictor1 = CreatePaddlePredictor<AnalysisConfig>(config1);
-
-  // Prepare inputs
-  std::vector<PaddleTensor> paddle_tensor_feeds(1);
-  PrepareInputs(&paddle_tensor_feeds, batch_size);
-
-  // Prepare outputs
-  std::vector<PaddleTensor> outputs0;
-  std::vector<PaddleTensor> outputs1;
-  CHECK(predictor0->Run(paddle_tensor_feeds, &outputs0));
-  CHECK(predictor1->Run(paddle_tensor_feeds, &outputs1, batch_size));
-
-  const size_t num_elements = outputs0.front().data.length() / sizeof(float);
-  const size_t num_elements1 = outputs1.front().data.length() / sizeof(float);
-  EXPECT_EQ(num_elements, num_elements1);
-
-  auto *data0 = static_cast<float *>(outputs0.front().data.data());
-  auto *data1 = static_cast<float *>(outputs1.front().data.data());
-
-  ASSERT_GT(num_elements, 0UL);
-  for (size_t i = 0; i < std::min(num_elements, num_elements1); i++) {
-    EXPECT_NEAR(data0[i], data1[i], 1e-3);
+void profile(std::string model_dir, bool use_analysis, bool use_tensorrt) {
+  std::vector<std::vector<PaddleTensor>> inputs_all;
+  if (!FLAGS_prog_filename.empty() && !FLAGS_param_filename.empty()) {
+    SetFakeImageInput(&inputs_all, model_dir, true, FLAGS_prog_filename,
+                      FLAGS_param_filename);
+  } else {
+    SetFakeImageInput(&inputs_all, model_dir, false, "__model__", "");
   }
-}
-TEST(trt_models_test, mobilenet) {
-  CompareTensorRTWithFluid(1, FLAGS_dirname + "/" + "mobilenet");
-}
-TEST(trt_models_test, resnet50) {
-  CompareTensorRTWithFluid(1, FLAGS_dirname + "/" + "resnet50");
-}
-TEST(trt_models_test, resnext50) {
-  CompareTensorRTWithFluid(1, FLAGS_dirname + "/" + "resnext50");
+  std::vector<PaddleTensor> outputs;
+  if (use_analysis || use_tensorrt) {
+    contrib::AnalysisConfig config(true);
+    SetConfig(&config, model_dir, true, use_tensorrt,
+              FLAGS_batch_size);
+    TestPrediction(reinterpret_cast<PaddlePredictor::Config *>(&config),
+                   inputs_all, &outputs, FLAGS_num_threads, true);
+  } else {
+    NativeConfig config;
+    SetConfig(&config, model_dir, true, false);
+    TestPrediction(reinterpret_cast<PaddlePredictor::Config *>(&config),
+                   inputs_all, &outputs, FLAGS_num_threads, false);
+  }
 }
-TEST(trt_models_test, raw_gpu) {
-  std::string model_dir = FLAGS_dirname + "/" + "mobilenet";
-  auto config0 = GetConfigNative();
-  config0.model_dir = model_dir;
-  int batch_size = 2;
-
-  AnalysisConfig config1(true);
-  config1.fraction_of_gpu_memory = 0.1;
-  config1.enable_ir_optim = true;
-  config1.model_dir = model_dir;
+void compare(std::string model_dir, bool use_tensorrt) {
+  std::vector<std::vector<PaddleTensor>> inputs_all;
+  if (!FLAGS_prog_filename.empty() && !FLAGS_param_filename.empty()) {
+    SetFakeImageInput(&inputs_all, model_dir, true, FLAGS_prog_filename,
+                      FLAGS_param_filename);
+  } else {
+    SetFakeImageInput(&inputs_all, model_dir, false, "__model__", "");
+  }
 
-  auto predictor0 = CreatePaddlePredictor(config0);
-  auto predictor1 = CreatePaddlePredictor(config1);
+  std::vector<PaddleTensor> native_outputs;
+  NativeConfig native_config;
+  SetConfig(&native_config, model_dir, true, false,
+            FLAGS_batch_size);
+  TestOneThreadPrediction(
+      reinterpret_cast<PaddlePredictor::Config *>(&native_config), inputs_all,
+      &native_outputs, false);
+
+  std::vector<PaddleTensor> analysis_outputs;
+  contrib::AnalysisConfig analysis_config(true);
+  SetConfig(&analysis_config, model_dir, true,
+            use_tensorrt, FLAGS_batch_size);
+  TestOneThreadPrediction(
+      reinterpret_cast<PaddlePredictor::Config *>(&analysis_config), inputs_all,
+      &analysis_outputs, true);
+
+  CompareResult(native_outputs, analysis_outputs);
+}
 
-  // Prepare inputs
-  std::vector<PaddleTensor> paddle_tensor_feeds(1);
-  PrepareInputs(&paddle_tensor_feeds, batch_size);
+TEST(TensorRT_mobilenet, compare) {
+  std::string model_dir = FLAGS_infer_model + "/mobilenet";
+  compare(model_dir, /* use_tensorrt */ true);
+}
 
-  // Prepare outputs
-  std::vector<PaddleTensor> outputs0;
-  std::vector<PaddleTensor> outputs1;
-  CHECK(predictor0->Run(paddle_tensor_feeds, &outputs0));
-
-  CHECK(predictor1->Run(paddle_tensor_feeds, &outputs1, batch_size));
+TEST(TensorRT_resnet50, compare) {
+  std::string model_dir = FLAGS_infer_model + "/resnet50";
+  compare(model_dir, /* use_tensorrt */ true);
+}
 
-  const size_t num_elements = outputs0.front().data.length() / sizeof(float);
-  const size_t num_elements1 = outputs1.front().data.length() / sizeof(float);
-  EXPECT_EQ(num_elements, num_elements1);
+TEST(TensorRT_resnext50, compare) {
+  std::string model_dir = FLAGS_infer_model + "/resnext50";
+  compare(model_dir, /* use_tensorrt */ true);
+}
 
-  auto *data0 = static_cast<float *>(outputs0.front().data.data());
-  auto *data1 = static_cast<float *>(outputs1.front().data.data());
+TEST(TensorRT_resnext50, profile) {
+  std::string model_dir = FLAGS_infer_model + "/resnext50";
+  profile(model_dir, /* use_analysis */ true, FLAGS_use_tensorrt);
+}
 
-  ASSERT_GT(num_elements, 0UL);
-  for (size_t i = 0; i < std::min(num_elements, num_elements1); i++) {
-    EXPECT_NEAR(data0[i], data1[i], 1e-3);
-  }
+TEST(TensorRT_mobilenet, analysis) {
+  std::string model_dir = FLAGS_infer_model + "/" + "mobilenet";
+  compare(model_dir, /* use_tensorrt */ false);
 }
 
+}  // namespace inference
 }  // namespace paddle
 
 USE_PASS(tensorrt_subgraph_pass);
diff --git a/paddle/fluid/operators/fc_op.cc b/paddle/fluid/operators/fc_op.cc
index fa4dec9cf118cef9b836943fd4eae90d23e6218a..e80249fc87855311479b35af61f872182292795a 100644
--- a/paddle/fluid/operators/fc_op.cc
+++ b/paddle/fluid/operators/fc_op.cc
@@ -27,11 +27,9 @@ void FCOp::InferShape(framework::InferShapeContext* ctx) const {
                  "Out(Output) of Fully Connected should not be null.");
   PADDLE_ENFORCE(ctx->HasInput("W"),
                  "W(Input) of Fully Connected should not be null.");
-  // NCHW
+
   auto in_dims = ctx->GetInputDim("Input");
-  // IO, I=C*H*W
   auto w_dims = ctx->GetInputDim("W");
-  std::vector<int64_t> output_shape({in_dims[0], w_dims[1]});
 
   if (ctx->HasInput("Bias")) {
     auto bias_dims = ctx->GetInputDim("Bias");
@@ -44,14 +42,32 @@ void FCOp::InferShape(framework::InferShapeContext* ctx) const {
                         "The shape of Bias must be [1, dim].");
     }
   }
-  PADDLE_ENFORCE(in_dims.size() == 2 || in_dims.size() == 4,
-                 "Fully Connected input should be 2-D or 4-D tensor.");
+
+  if (ctx->Attrs().Get<bool>("use_mkldnn")) {
+    PADDLE_ENFORCE(in_dims.size() == 2 || in_dims.size() == 4,
+                   "Fully Connected input should be 2-D or 4-D tensor.");
+  }
   PADDLE_ENFORCE_EQ(w_dims.size(), 2UL,
                     "Fully Connected input should be 2-D tensor.");
-  PADDLE_ENFORCE_EQ(framework::product(in_dims) / in_dims[0], w_dims[0],
-                    "Fully Connected input and weigth size do not match.");
+  int in_num_col_dims = ctx->Attrs().Get<int>("in_num_col_dims");
+  PADDLE_ENFORCE_GT(
+      in_dims.size(), in_num_col_dims,
+      "The input tensor Input's rank of FCOp should be larger than "
+      "in_num_col_dims.");
+
+  auto in_mat_dims = framework::flatten_to_2d(in_dims, in_num_col_dims);
+  PADDLE_ENFORCE_EQ(
+      in_mat_dims[1], w_dims[0],
+      "Fully Connected input and weight size do not match.");
+
+  std::vector<int64_t> output_dims;
+  output_dims.reserve(static_cast<size_t>(in_num_col_dims + 1));
+  for (int i = 0; i < in_num_col_dims; ++i) {
+    output_dims.push_back(in_dims[i]);
+  }
+  output_dims.push_back(w_dims[1]);
 
-  ctx->SetOutputDim("Out", framework::make_ddim(output_shape));
+  ctx->SetOutputDim("Out", framework::make_ddim(output_dims));
   ctx->ShareLoD("Input", "Out");
 }
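The InferShape hunk above generalizes fc from strictly 2-D/4-D inputs to any rank: the first in_num_col_dims axes are kept as the "row" axes, the remaining axes are flattened and must match W's input width, and Out gets the leading axes plus W's output width. A standalone sketch of that shape arithmetic (it mirrors framework::flatten_to_2d as used above; the helper name is illustrative, not Paddle code):

#include <cassert>
#include <cstdint>
#include <vector>

std::vector<int64_t> FcOutputDims(const std::vector<int64_t> &in_dims,
                                  int64_t w_in, int64_t w_out,
                                  int in_num_col_dims) {
  assert(static_cast<int>(in_dims.size()) > in_num_col_dims);
  int64_t flat_cols = 1;  // product of the trailing axes
  for (size_t i = in_num_col_dims; i < in_dims.size(); ++i)
    flat_cols *= in_dims[i];
  // Matches PADDLE_ENFORCE_EQ(in_mat_dims[1], w_dims[0]) above.
  assert(flat_cols == w_in);
  std::vector<int64_t> out(in_dims.begin(), in_dims.begin() + in_num_col_dims);
  out.push_back(w_out);
  return out;
}

int main() {
  // e.g. Input [2, 3, 4], W [4, 5], in_num_col_dims = 2 -> Out [2, 3, 5].
  // With the default in_num_col_dims = 1 this reduces to the old [N, O].
  auto out = FcOutputDims({2, 3, 4}, 4, 5, 2);
  assert((out == std::vector<int64_t>{2, 3, 5}));
}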
"); + AddInput("Input", "(Tensor), The input tensor of fully connected operator."); AddInput("W", "(Tensor), The weight fc op with shape (I, O)."); AddInput("Bias", "(Tensor, optional) Bias vector with shape (1 x O") .AsDispensable(); + AddAttr("in_num_col_dims", + "(int, default 1), The fc op can take tensors with more than " + "two dimensions as its inputs.") + .SetDefault(1) + .EqualGreaterThan(1); AddOutput("Out", "(Tensor) The output tensor of fully connected operator. "); AddAttr("use_mkldnn", "(bool, default false) Only used in mkldnn kernel") @@ -131,13 +150,15 @@ class FCOpKernel : public framework::OpKernel { auto output = ctx.Output("Out"); auto in_dims = input->dims(); auto w_dims = w->dims(); + auto out_dims = output->dims(); + int M = framework::product(out_dims) / out_dims[out_dims.size() - 1]; const T* input_data = input->data(); const T* w_data = w->data(); T* output_data = output->mutable_data(ctx.GetPlace()); auto blas = math::GetBlas(ctx); math::FCCompute( - blas, in_dims[0], w_dims[1], w_dims[0], input_data, w_data, output_data, + blas, M, w_dims[1], w_dims[0], input_data, w_data, output_data, bias ? bias->data() : NULL); // TODO(TJ): fuse act diff --git a/paddle/fluid/operators/hash_op.cc b/paddle/fluid/operators/hash_op.cc index b9ebe71a3d7ae270a10a45f4805652415078b363..b2c2c7954b79658e66f1524a81bcad0b7bf22c35 100644 --- a/paddle/fluid/operators/hash_op.cc +++ b/paddle/fluid/operators/hash_op.cc @@ -38,7 +38,7 @@ class HashOp : public framework::OperatorWithKernel { std::vector out_dims; out_dims.reserve(dims.size() + 1); // copy all dims except the last one - for (size_t i = 0u; i != dims.size() - 1; ++i) { + for (int i = 0u; i != dims.size() - 1; ++i) { out_dims.emplace_back(dims[i]); } int num_hash = ctx->Attrs().Get("num_hash"); diff --git a/paddle/fluid/operators/math/jit_code.cc b/paddle/fluid/operators/math/jit_code.cc index 6b3eecfbd11471b5d95dcb10c91acc536f78cb85..e46f60f764ab9f1c292db339a5b38b976de5a11a 100644 --- a/paddle/fluid/operators/math/jit_code.cc +++ b/paddle/fluid/operators/math/jit_code.cc @@ -118,6 +118,39 @@ void VXXJitCode::generate() { ret(); } +bool ReluJitCode::init(int d) { return MayIUse(avx); } + +void ReluJitCode::generate() { + int offset = 0; + vxorps(ymm_zero, ymm_zero, ymm_zero); + for (int i = 0; i < num_ / AVX_FLOAT_BLOCK; ++i) { + vmovups(ymm_src, ptr[param1 + offset]); + vmaxps(ymm_dst, ymm_zero, ymm_src); + vmovups(ptr[param2 + offset], ymm_dst); + offset += sizeof(float) * AVX_FLOAT_BLOCK; + } + int rest = num_ % AVX_FLOAT_BLOCK; + if (rest >= 4) { + vmovups(xmm_src, ptr[param1 + offset]); + vmaxps(xmm_dst, xmm_zero, xmm_src); + vmovups(ptr[param2 + offset], xmm_dst); + offset += sizeof(float) * 4; + rest -= 4; + } + if (rest >= 2) { + vmovups(xmm_src, ptr[param1 + offset]); + vmaxps(xmm_dst, xmm_zero, xmm_src); + vmovq(ptr[param2 + offset], xmm_dst); + offset += sizeof(float) * 2; + rest -= 2; + } + if (rest > 0) { + vmovups(xmm_src, ptr[param1 + offset]); + vmaxps(xmm_dst, xmm_zero, xmm_src); + vmovss(ptr[param2 + offset], xmm_dst); + } + ret(); +} } // namespace gen } // namespace jitkernel } // namespace math diff --git a/paddle/fluid/operators/math/jit_code.h b/paddle/fluid/operators/math/jit_code.h index aaedb0ae10323eeddfba9512d9e47c7a22320610..3c242870a24c5bb29d34d4b99406c5df8cec6763 100644 --- a/paddle/fluid/operators/math/jit_code.h +++ b/paddle/fluid/operators/math/jit_code.h @@ -85,6 +85,29 @@ class VXXJitCode : public JitCode { ymm_t ymm_zero = ymm_t(3); }; +class ReluJitCode : public JitCode { + public: + 
diff --git a/paddle/fluid/operators/math/jit_code.h b/paddle/fluid/operators/math/jit_code.h
index aaedb0ae10323eeddfba9512d9e47c7a22320610..3c242870a24c5bb29d34d4b99406c5df8cec6763 100644
--- a/paddle/fluid/operators/math/jit_code.h
+++ b/paddle/fluid/operators/math/jit_code.h
@@ -85,6 +85,29 @@ class VXXJitCode : public JitCode {
   ymm_t ymm_zero = ymm_t(3);
 };
 
+class ReluJitCode : public JitCode {
+ public:
+  DECLARE_JIT_CODE(ReluJitCode);
+  explicit ReluJitCode(int d, size_t code_size = 256 * 1024,
+                       void* code_ptr = nullptr)
+      : JitCode(code_size, code_ptr), num_(d) {}
+  static bool init(int d);
+  void generate() override;
+
+ private:
+  int num_;
+  reg64_t param1{abi_param1};
+  reg64_t param2{abi_param2};
+
+  xmm_t xmm_zero = xmm_t(0);
+  xmm_t xmm_src = xmm_t(1);
+  xmm_t xmm_dst = xmm_t(1);
+
+  ymm_t ymm_zero = ymm_t(0);
+  ymm_t ymm_src = ymm_t(1);
+  ymm_t ymm_dst = ymm_t(1);
+};
+
 }  // namespace gen
 }  // namespace jitkernel
 }  // namespace math
diff --git a/paddle/fluid/operators/math/jit_kernel.h b/paddle/fluid/operators/math/jit_kernel.h
index e9b259282cd00cc2afc46634423ec09590bf5dd3..cd3a45e66773c89e45e80ab77ebd925abd6cbe53 100644
--- a/paddle/fluid/operators/math/jit_kernel.h
+++ b/paddle/fluid/operators/math/jit_kernel.h
@@ -97,37 +97,38 @@ class VAddBiasKernel : public Kernel {
 
 template <typename T>
 class VActKernel : public Kernel {
 public:
-  virtual void Compute(const T *x, T *y) const = 0;
+  virtual void ComputeDeprecated(const T *x, T *y) const = 0;
 };
 
 template <typename T>
 class VReluKernel : public VActKernel<T> {
 public:
-  virtual void Compute(const T *x, T *y) const = 0;
+  virtual void ComputeDeprecated(const T *x, T *y) const = 0;
+  void (*Compute)(const T *, T *, int);
 };
 
 template <typename T>
 class VIdentityKernel : public VActKernel<T> {
 public:
-  virtual void Compute(const T *x, T *y) const = 0;
+  virtual void ComputeDeprecated(const T *x, T *y) const = 0;
 };
 
 template <typename T>
 class VExpKernel : public VActKernel<T> {
 public:
-  virtual void Compute(const T *x, T *y) const = 0;
+  virtual void ComputeDeprecated(const T *x, T *y) const = 0;
 };
 
 template <typename T>
 class VSigmoidKernel : public VActKernel<T> {
 public:
-  virtual void Compute(const T *x, T *y) const = 0;
+  virtual void ComputeDeprecated(const T *x, T *y) const = 0;
 };
 
 template <typename T>
 class VTanhKernel : public VActKernel<T> {
 public:
-  virtual void Compute(const T *x, T *y) const = 0;
+  virtual void ComputeDeprecated(const T *x, T *y) const = 0;
 };
 
 template <typename T>
diff --git a/paddle/fluid/operators/math/jit_kernel_blas.cc b/paddle/fluid/operators/math/jit_kernel_blas.cc
index c4bfbcf925a2bbdc39f8468049c58e126d3eba1b..cf46a210afbd4903dc3841f27765c390f721c763 100644
--- a/paddle/fluid/operators/math/jit_kernel_blas.cc
+++ b/paddle/fluid/operators/math/jit_kernel_blas.cc
@@ -71,6 +71,13 @@ void VAddBiasRefer(const T* a, const T* x, T* y, int n) {
   }
 }
 
+template <typename T>
+void VReluRefer(const T* x, T* y, int n) {
+  for (int i = 0; i < n; ++i) {
+    y[i] = x[i] > 0 ? x[i] : 0;
+  }
+}
+
 #ifdef PADDLE_WITH_MKLML
 template <typename T>
 void VMulMKL(const T* x, const T* y, T* z, int n);
@@ -344,124 +351,60 @@ bool VAddBiasKernelImpl<float>::useJIT(int d) {
 }
 #endif
 
-#undef DECLARE_STATIC_FUNC
-
-REGISTER_JITKERNEL(vmul, VMulKernel);
-REGISTER_JITKERNEL(vadd, VAddKernel);
-REGISTER_JITKERNEL(vaddrelu, VAddReluKernel);
-REGISTER_JITKERNEL(vscal, VScalKernel);
-REGISTER_JITKERNEL(vaddbias, VAddBiasKernel);
-
 /* VRelu JitKernel */
-template <typename T, platform::jit::cpu_isa_t isa, jit_block>
+template <typename T>
 class VReluKernelImpl : public VReluKernel<T> {
 public:
-  explicit VReluKernelImpl(int d) : VReluKernel<T>() { this->num_ = d; }
-  void Compute(const T* x, T* y) const override {
-    for (int i = 0; i < this->num_; ++i) {
-      y[i] = x[i] > 0 ? x[i] : 0;
+  DECLARE_STATIC_FUNC;
+  explicit VReluKernelImpl(int d) : VReluKernel<T>() {
+    this->num_ = d;  // TODO(TJ): remove me when ComputeDeprecated done
+#ifdef PADDLE_WITH_XBYAK
+    if (useJIT(d)) {
+      size_t sz = 96 /* init */ +
+                  d / AVX_FLOAT_BLOCK * 4 /* instructions */ *
+                      8 /* average bytes per instruction */;
+      jitcode_.reset(new gen::ReluJitCode(d, sz > 4096 ?
sz : 4096)); + this->Compute = jitcode_->getCode(); + return; } - } -}; - -#define INTRI8_FLOAT(isa) \ - template <> \ - void VReluKernelImpl::Compute(const float* x, float* y) \ - const { \ - __m256 tmp = _mm256_loadu_ps(x); \ - tmp = _mm256_max_ps(tmp, _mm256_setzero_ps()); \ - _mm256_storeu_ps(y, tmp); \ - } - -#define INTRI16_FLOAT(isa) \ - template <> \ - void VReluKernelImpl::Compute(const float* x, float* y) \ - const { \ - __m256 zeros = _mm256_setzero_ps(); \ - __m256 tmp0 = _mm256_loadu_ps(x); \ - __m256 tmp1 = _mm256_loadu_ps(x + 8); \ - tmp0 = _mm256_max_ps(tmp0, zeros); \ - tmp1 = _mm256_max_ps(tmp1, zeros); \ - _mm256_storeu_ps(y, tmp0); \ - _mm256_storeu_ps(y + 8, tmp1); \ - } +#endif -#define INTRI_GT8LT16_FLOAT(isa) \ - template <> \ - VReluKernelImpl::VReluKernelImpl(int d) \ - : VReluKernel() { \ - this->num_ = d; \ - this->end_ = AVX_FLOAT_BLOCK; \ - this->rest_ = d - AVX_FLOAT_BLOCK; \ - } \ - template <> \ - void VReluKernelImpl::Compute(const float* x, \ - float* y) const { \ - __m256 zeros = _mm256_setzero_ps(); \ - __m256 tmp0 = _mm256_loadu_ps(x); \ - __m256 tmp1 = _mm256_loadu_ps(x + this->rest_); \ - tmp0 = _mm256_max_ps(tmp0, zeros); \ - tmp1 = _mm256_max_ps(tmp1, zeros); \ - _mm256_storeu_ps(y, tmp0); \ - _mm256_storeu_ps(y + this->rest_, tmp1); \ + this->Compute = VReluRefer; } - -#define INTRI_GT16_FLOAT(isa) \ - template <> \ - VReluKernelImpl::VReluKernelImpl(int d) \ - : VReluKernel() { \ - this->num_ = d; \ - this->end_ = d - d % AVX_FLOAT_BLOCK; \ - this->rest_ = d - AVX_FLOAT_BLOCK; \ - } \ - template <> \ - void VReluKernelImpl::Compute(const float* x, float* y) \ - const { \ - __m256 zeros = _mm256_setzero_ps(); \ - for (int i = 0; i < this->end_; i += AVX_FLOAT_BLOCK) { \ - __m256 tmp = _mm256_loadu_ps(x + i); \ - tmp = _mm256_max_ps(tmp, zeros); \ - _mm256_storeu_ps(y + i, tmp); \ - } \ - __m256 tmp = _mm256_loadu_ps(x + this->rest_); \ - tmp = _mm256_max_ps(tmp, zeros); \ - _mm256_storeu_ps(y + this->rest_, tmp); \ + void ComputeDeprecated(const T* x, T* y) const override { + VReluRefer(x, y, this->num_); } +#ifdef PADDLE_WITH_XBYAK -#ifdef __AVX__ -INTRI8_FLOAT(jit::avx); -INTRI16_FLOAT(jit::avx); -INTRI_GT8LT16_FLOAT(jit::avx); -INTRI_GT16_FLOAT(jit::avx); -#endif -#ifdef __AVX2__ -INTRI8_FLOAT(jit::avx2); -INTRI16_FLOAT(jit::avx2); -INTRI_GT8LT16_FLOAT(jit::avx2); -INTRI_GT16_FLOAT(jit::avx2); + private: + std::unique_ptr jitcode_{nullptr}; #endif -#ifdef __AVX512F__ -// TODO(TJ): refine avx512 -INTRI8_FLOAT(jit::avx512f); -INTRI16_FLOAT(jit::avx512f); -INTRI_GT8LT16_FLOAT(jit::avx512f); -INTRI_GT16_FLOAT(jit::avx512f); +}; + +#ifdef PADDLE_WITH_XBYAK +template <> +bool VReluKernelImpl::useJIT(int d) { + return gen::ReluJitCode::init(d); +} #endif -#undef INTRI8_FLOAT -#undef INTRI16_FLOAT -#undef INTRI_GT8LT16_FLOAT -#undef INTRI_GT16_FLOAT +#undef DECLARE_STATIC_FUNC + +REGISTER_JITKERNEL(vmul, VMulKernel); +REGISTER_JITKERNEL(vadd, VAddKernel); +REGISTER_JITKERNEL(vaddrelu, VAddReluKernel); +REGISTER_JITKERNEL(vscal, VScalKernel); +REGISTER_JITKERNEL(vaddbias, VAddBiasKernel); +REGISTER_JITKERNEL(vrelu, VReluKernel); /* An empty JitKernel */ template class VIdentityKernelImpl : public VIdentityKernel { public: explicit VIdentityKernelImpl(int d) : VIdentityKernel() { this->num_ = d; } - void Compute(const T* x, T* y) const override {} + void ComputeDeprecated(const T* x, T* y) const override {} }; -REGISTER_JITKERNEL_DEPRECATED(vrelu, VReluKernel); REGISTER_JITKERNEL_DEPRECATED(videntity, VIdentityKernel); } // namespace jitkernel diff --git 
a/paddle/fluid/operators/math/jit_kernel_exp.cc b/paddle/fluid/operators/math/jit_kernel_exp.cc index c55e54a13f539014c0f582436ca1a105d0b0fedd..2ac9e1092362f60ea3d89da0c971a365b45f39ea 100644 --- a/paddle/fluid/operators/math/jit_kernel_exp.cc +++ b/paddle/fluid/operators/math/jit_kernel_exp.cc @@ -35,7 +35,7 @@ template class VExpKernelImpl : public VExpKernel { public: explicit VExpKernelImpl(int d) : VExpKernel() { this->num_ = d; } - void Compute(const T* x, T* y) const override { + void ComputeDeprecated(const T* x, T* y) const override { for (int i = 0; i < this->num_; ++i) { y[i] = std::exp(x[i]); } @@ -43,18 +43,18 @@ class VExpKernelImpl : public VExpKernel { }; #ifdef PADDLE_WITH_MKLML -#define MKL_FLOAT(isa, block) \ - template <> \ - void VExpKernelImpl::Compute(const float* x, float* y) \ - const { \ - platform::dynload::vsExp(this->num_, x, y); \ +#define MKL_FLOAT(isa, block) \ + template <> \ + void VExpKernelImpl::ComputeDeprecated(const float* x, \ + float* y) const { \ + platform::dynload::vsExp(this->num_, x, y); \ } -#define MKL_DOUBLE(isa, block) \ - template <> \ - void VExpKernelImpl::Compute(const double* x, double* y) \ - const { \ - platform::dynload::vdExp(this->num_, x, y); \ +#define MKL_DOUBLE(isa, block) \ + template <> \ + void VExpKernelImpl::ComputeDeprecated( \ + const double* x, double* y) const { \ + platform::dynload::vdExp(this->num_, x, y); \ } FOR_EACH_ISA(MKL_FLOAT, kLT8); FOR_EACH_ISA(MKL_FLOAT, kGT8LT16); @@ -211,24 +211,24 @@ __m256 ExpAVX2(__m256 x) { } // namespace detail -#define INTRI8_FLOAT(isa, expisa) \ - template <> \ - void VExpKernelImpl::Compute(const float* x, float* y) \ - const { \ - __m256 tmp = _mm256_loadu_ps(x); \ - _mm256_storeu_ps(y, expisa(tmp)); \ +#define INTRI8_FLOAT(isa, expisa) \ + template <> \ + void VExpKernelImpl::ComputeDeprecated(const float* x, \ + float* y) const { \ + __m256 tmp = _mm256_loadu_ps(x); \ + _mm256_storeu_ps(y, expisa(tmp)); \ } -#define INTRI16_FLOAT(isa, expisa) \ - template <> \ - void VExpKernelImpl::Compute(const float* x, float* y) \ - const { \ - __m256 tmp0 = _mm256_loadu_ps(x); \ - __m256 tmp1 = _mm256_loadu_ps(x + 8); \ - tmp0 = expisa(tmp0); \ - tmp1 = expisa(tmp1); \ - _mm256_storeu_ps(y, tmp0); \ - _mm256_storeu_ps(y + 8, tmp1); \ +#define INTRI16_FLOAT(isa, expisa) \ + template <> \ + void VExpKernelImpl::ComputeDeprecated(const float* x, \ + float* y) const { \ + __m256 tmp0 = _mm256_loadu_ps(x); \ + __m256 tmp1 = _mm256_loadu_ps(x + 8); \ + tmp0 = expisa(tmp0); \ + tmp1 = expisa(tmp1); \ + _mm256_storeu_ps(y, tmp0); \ + _mm256_storeu_ps(y + 8, tmp1); \ } #ifdef __AVX__ @@ -260,14 +260,14 @@ class VSigmoidKernelImpl : public VSigmoidKernel { this->num_ = d; vexp_ = KernelPool::Instance().template Get>(d); } - void Compute(const T* x, T* y) const override { + void ComputeDeprecated(const T* x, T* y) const override { const T min = SIGMOID_THRESHOLD_MIN; const T max = SIGMOID_THRESHOLD_MAX; for (int i = 0; i < this->num_; ++i) { y[i] = (x[i] < min) ? min : ((x[i] > max) ? 
max : x[i]); y[i] = static_cast(0) - y[i]; } - vexp_->Compute(y, y); + vexp_->ComputeDeprecated(y, y); for (int i = 0; i < this->num_; ++i) { y[i] = static_cast(1) / (static_cast(1) + y[i]); } @@ -285,30 +285,30 @@ class VSigmoidKernelImpl : public VSigmoidKernel { tmp = _mm256_add_ps(_mm256_set1_ps(1.0f), tmp); \ tmp = _mm256_div_ps(_mm256_set1_ps(1.0f), tmp) -#define INTRI8_FLOAT(isa, expisa) \ - template <> \ - void VSigmoidKernelImpl::Compute(const float* x, float* y) \ - const { \ - /* TODO(TJ): try to use static const*/ \ - __m256 max = _mm256_set1_ps(SIGMOID_THRESHOLD_MAX); \ - __m256 min = _mm256_set1_ps(SIGMOID_THRESHOLD_MIN); \ - __m256 tmp = _mm256_loadu_ps(x); \ - INTRI_SIGMOID(tmp, min, max, expisa); \ - _mm256_storeu_ps(y, tmp); \ +#define INTRI8_FLOAT(isa, expisa) \ + template <> \ + void VSigmoidKernelImpl::ComputeDeprecated( \ + const float* x, float* y) const { \ + /* TODO(TJ): try to use static const*/ \ + __m256 max = _mm256_set1_ps(SIGMOID_THRESHOLD_MAX); \ + __m256 min = _mm256_set1_ps(SIGMOID_THRESHOLD_MIN); \ + __m256 tmp = _mm256_loadu_ps(x); \ + INTRI_SIGMOID(tmp, min, max, expisa); \ + _mm256_storeu_ps(y, tmp); \ } -#define INTRI16_FLOAT(isa, expisa) \ - template <> \ - void VSigmoidKernelImpl::Compute(const float* x, \ - float* y) const { \ - __m256 max = _mm256_set1_ps(SIGMOID_THRESHOLD_MAX); \ - __m256 min = _mm256_set1_ps(SIGMOID_THRESHOLD_MIN); \ - __m256 tmp0 = _mm256_loadu_ps(x); \ - __m256 tmp1 = _mm256_loadu_ps(x + 8); \ - INTRI_SIGMOID(tmp0, min, max, expisa); \ - INTRI_SIGMOID(tmp1, min, max, expisa); \ - _mm256_storeu_ps(y, tmp0); \ - _mm256_storeu_ps(y + 8, tmp1); \ +#define INTRI16_FLOAT(isa, expisa) \ + template <> \ + void VSigmoidKernelImpl::ComputeDeprecated( \ + const float* x, float* y) const { \ + __m256 max = _mm256_set1_ps(SIGMOID_THRESHOLD_MAX); \ + __m256 min = _mm256_set1_ps(SIGMOID_THRESHOLD_MIN); \ + __m256 tmp0 = _mm256_loadu_ps(x); \ + __m256 tmp1 = _mm256_loadu_ps(x + 8); \ + INTRI_SIGMOID(tmp0, min, max, expisa); \ + INTRI_SIGMOID(tmp1, min, max, expisa); \ + _mm256_storeu_ps(y, tmp0); \ + _mm256_storeu_ps(y + 8, tmp1); \ } #define INTRI_GT8LT16_FLOAT(isa, expisa) \ @@ -322,8 +322,8 @@ class VSigmoidKernelImpl : public VSigmoidKernel { KernelPool::Instance().template Get>(this->rest_); \ } \ template <> \ - void VSigmoidKernelImpl::Compute(const float* x, \ - float* y) const { \ + void VSigmoidKernelImpl::ComputeDeprecated( \ + const float* x, float* y) const { \ __m256 max = _mm256_set1_ps(SIGMOID_THRESHOLD_MAX); \ __m256 min = _mm256_set1_ps(SIGMOID_THRESHOLD_MIN); \ __m256 tmp = _mm256_loadu_ps(x); \ @@ -335,7 +335,7 @@ class VSigmoidKernelImpl : public VSigmoidKernel { y[i] = (x[i] < min_) ? min_ : ((x[i] > max_) ? max_ : x[i]); \ y[i] = 0.f - y[i]; \ } \ - vexp_->Compute(y + this->end_, y + this->end_); \ + vexp_->ComputeDeprecated(y + this->end_, y + this->end_); \ for (int i = this->end_; i < this->num_; ++i) { \ y[i] = 1.f / (1.f + y[i]); \ } \ @@ -352,8 +352,8 @@ class VSigmoidKernelImpl : public VSigmoidKernel { KernelPool::Instance().template Get>(this->rest_); \ } \ template <> \ - void VSigmoidKernelImpl::Compute(const float* x, \ - float* y) const { \ + void VSigmoidKernelImpl::ComputeDeprecated( \ + const float* x, float* y) const { \ __m256 max = _mm256_set1_ps(SIGMOID_THRESHOLD_MAX); \ __m256 min = _mm256_set1_ps(SIGMOID_THRESHOLD_MIN); \ for (int i = 0; i < this->end_; i += AVX_FLOAT_BLOCK) { \ @@ -367,7 +367,7 @@ class VSigmoidKernelImpl : public VSigmoidKernel { y[i] = (x[i] < min_) ? min_ : ((x[i] > max_) ? 
max_ : x[i]); \ y[i] = 0.f - y[i]; \ } \ - vexp_->Compute(y + this->end_, y + this->end_); \ + vexp_->ComputeDeprecated(y + this->end_, y + this->end_); \ for (int i = this->end_; i < this->num_; ++i) { \ y[i] = 1.f / (1.f + y[i]); \ } \ @@ -408,10 +408,10 @@ class VTanhKernelImpl : public VTanhKernel { vsigmoid_ = KernelPool::Instance().template Get>(d); vaddbias_ = KernelPool::Instance().template Get>(d); } - void Compute(const T* x, T* y) const override { + void ComputeDeprecated(const T* x, T* y) const override { const T a = static_cast(2), b = static_cast(-1); vscal_->Compute(&a, x, y, this->num_); - vsigmoid_->Compute(y, y); + vsigmoid_->ComputeDeprecated(y, y); vscal_->Compute(&a, y, y, this->num_); vaddbias_->Compute(&b, y, y, this->num_); } @@ -430,25 +430,25 @@ class VTanhKernelImpl : public VTanhKernel { tmp = _mm256_div_ps(_mm256_set1_ps(2.0f), tmp); \ tmp = _mm256_sub_ps(tmp, _mm256_set1_ps(1.0f)) -#define INTRI8_FLOAT(isa, expisa) \ - template <> \ - void VTanhKernelImpl::Compute(const float* x, float* y) \ - const { \ - __m256 tmp = _mm256_loadu_ps(x); \ - INTRI_VTANH(tmp, expisa); \ - _mm256_storeu_ps(y, tmp); \ +#define INTRI8_FLOAT(isa, expisa) \ + template <> \ + void VTanhKernelImpl::ComputeDeprecated(const float* x, \ + float* y) const { \ + __m256 tmp = _mm256_loadu_ps(x); \ + INTRI_VTANH(tmp, expisa); \ + _mm256_storeu_ps(y, tmp); \ } -#define INTRI16_FLOAT(isa, expisa) \ - template <> \ - void VTanhKernelImpl::Compute(const float* x, float* y) \ - const { \ - __m256 tmp0 = _mm256_loadu_ps(x); \ - __m256 tmp1 = _mm256_loadu_ps(x + 8); \ - INTRI_VTANH(tmp0, expisa); \ - INTRI_VTANH(tmp1, expisa); \ - _mm256_storeu_ps(y, tmp0); \ - _mm256_storeu_ps(y + 8, tmp1); \ +#define INTRI16_FLOAT(isa, expisa) \ + template <> \ + void VTanhKernelImpl::ComputeDeprecated(const float* x, \ + float* y) const { \ + __m256 tmp0 = _mm256_loadu_ps(x); \ + __m256 tmp1 = _mm256_loadu_ps(x + 8); \ + INTRI_VTANH(tmp0, expisa); \ + INTRI_VTANH(tmp1, expisa); \ + _mm256_storeu_ps(y, tmp0); \ + _mm256_storeu_ps(y + 8, tmp1); \ } #define INTRI_GT8LT16_FLOAT(isa, expisa) \ @@ -466,8 +466,8 @@ class VTanhKernelImpl : public VTanhKernel { this->rest_); \ } \ template <> \ - void VTanhKernelImpl::Compute(const float* x, \ - float* y) const { \ + void VTanhKernelImpl::ComputeDeprecated( \ + const float* x, float* y) const { \ __m256 tmp = _mm256_loadu_ps(x); \ INTRI_VTANH(tmp, expisa); \ _mm256_storeu_ps(y, tmp); \ @@ -475,40 +475,40 @@ class VTanhKernelImpl : public VTanhKernel { y += AVX_FLOAT_BLOCK; \ const float a = 2.f, b = -1.f; \ vscal_->Compute(&a, x, y, this->num_); \ - vsigmoid_->Compute(y, y); \ + vsigmoid_->ComputeDeprecated(y, y); \ vscal_->Compute(&a, y, y, this->num_); \ vaddbias_->Compute(&b, y, y, this->num_); \ } -#define INTRI_GT16_FLOAT(isa, expisa) \ - template <> \ - VTanhKernelImpl::VTanhKernelImpl(int d) \ - : VTanhKernel() { \ - this->num_ = d; \ - this->rest_ = d % AVX_FLOAT_BLOCK; \ - this->end_ = d - this->rest_; \ - vscal_ = \ - KernelPool::Instance().template Get>(this->rest_); \ - vsigmoid_ = KernelPool::Instance().template Get>( \ - this->rest_); \ - vaddbias_ = KernelPool::Instance().template Get>( \ - this->rest_); \ - } \ - template <> \ - void VTanhKernelImpl::Compute(const float* x, float* y) \ - const { \ - for (int i = 0; i < this->end_; i += AVX_FLOAT_BLOCK) { \ - __m256 tmp = _mm256_loadu_ps(x + i); \ - INTRI_VTANH(tmp, expisa); \ - _mm256_storeu_ps(y + i, tmp); \ - } \ - x += this->end_; \ - y += this->end_; \ - const float a = 2.f, b = -1.f; \ - 
vscal_->Compute(&a, x, y, this->num_); \ - vsigmoid_->Compute(y, y); \ - vscal_->Compute(&a, y, y, this->num_); \ - vaddbias_->Compute(&b, y, y, this->num_); \ +#define INTRI_GT16_FLOAT(isa, expisa) \ + template <> \ + VTanhKernelImpl::VTanhKernelImpl(int d) \ + : VTanhKernel() { \ + this->num_ = d; \ + this->rest_ = d % AVX_FLOAT_BLOCK; \ + this->end_ = d - this->rest_; \ + vscal_ = \ + KernelPool::Instance().template Get>(this->rest_); \ + vsigmoid_ = KernelPool::Instance().template Get>( \ + this->rest_); \ + vaddbias_ = KernelPool::Instance().template Get>( \ + this->rest_); \ + } \ + template <> \ + void VTanhKernelImpl::ComputeDeprecated(const float* x, \ + float* y) const { \ + for (int i = 0; i < this->end_; i += AVX_FLOAT_BLOCK) { \ + __m256 tmp = _mm256_loadu_ps(x + i); \ + INTRI_VTANH(tmp, expisa); \ + _mm256_storeu_ps(y + i, tmp); \ + } \ + x += this->end_; \ + y += this->end_; \ + const float a = 2.f, b = -1.f; \ + vscal_->Compute(&a, x, y, this->num_); \ + vsigmoid_->ComputeDeprecated(y, y); \ + vscal_->Compute(&a, y, y, this->num_); \ + vaddbias_->Compute(&b, y, y, this->num_); \ } #ifdef __AVX__ diff --git a/paddle/fluid/operators/math/jit_kernel_rnn.cc b/paddle/fluid/operators/math/jit_kernel_rnn.cc index ba3e917377cf12192a068a9d71238442e12d5e5e..926221f0a75c461e275a72f16b4339ae28a8e988 100644 --- a/paddle/fluid/operators/math/jit_kernel_rnn.cc +++ b/paddle/fluid/operators/math/jit_kernel_rnn.cc @@ -175,26 +175,26 @@ class LSTMKernelImpl : public LSTMKernel { void ComputeCtHt(T* gates, const T* ct_1, T* ct, T* ht, const T* wp_data, T* checked) const override { // gates: W_ch, W_ih, W_fh, W_oh - act_gate_d3_->Compute(gates + d_, gates + d_); + act_gate_d3_->ComputeDeprecated(gates + d_, gates + d_); /* C_t = C_t-1 * fgated + cand_gated * igated */ - act_cand_d_->Compute(gates, gates); + act_cand_d_->ComputeDeprecated(gates, gates); vmul_d_->Compute(gates, gates + d_, gates + d_, d_); vmul_d_->Compute(ct_1, gates + d2_, gates + d2_, d_); vadd_d_->Compute(gates + d_, gates + d2_, ct, d_); /* H_t = act_cell(C_t) * ogated */ - act_cell_d_->Compute(ct, gates + d2_); + act_cell_d_->ComputeDeprecated(ct, gates + d2_); vmul_d_->Compute(gates + d2_, gates + d3_, ht, d_); } void ComputeC1H1(T* gates, T* ct, T* ht, const T* wp_data) const override { /* C_t = igated * cgated*/ - act_gate_d_->Compute(gates + d_, gates + d_); - act_cand_d_->Compute(gates, gates); + act_gate_d_->ComputeDeprecated(gates + d_, gates + d_); + act_cand_d_->ComputeDeprecated(gates, gates); vmul_d_->Compute(gates, gates + d_, ct, d_); /* H_t = act_cell(C_t) * ogated */ - act_gate_d_->Compute(gates + d3_, gates + d3_); - act_cell_d_->Compute(ct, gates + d2_); + act_gate_d_->ComputeDeprecated(gates + d3_, gates + d3_); + act_cell_d_->ComputeDeprecated(ct, gates + d2_); vmul_d_->Compute(gates + d2_, gates + d3_, ht, d_); } @@ -292,32 +292,32 @@ class PeepholeKernelImpl : public LSTMKernel { vmul_d_->Compute(wp_data, ct_1, checked, d_); vmul_d_->Compute(wp_data + d_, ct_1, checked + d_, d_); vadd_d2_->Compute(checked, gates + d_, gates + d_, d2_); - act_gate_d2_->Compute(gates + d_, gates + d_); + act_gate_d2_->ComputeDeprecated(gates + d_, gates + d_); /* C_t = C_t-1 * fgated + cand_gated * igated*/ - act_cand_d_->Compute(gates, gates); + act_cand_d_->ComputeDeprecated(gates, gates); vmul_d_->Compute(gates, gates + d_, gates + d_, d_); vmul_d_->Compute(ct_1, gates + d2_, gates + d2_, d_); vadd_d_->Compute(gates + d_, gates + d2_, ct, d_); /* get ogated*/ vmul_d_->Compute(wp_data + d2_, ct, gates + d_, d_); 
vadd_d_->Compute(gates + d_, gates + d3_, gates + d3_, d_); - act_gate_d_->Compute(gates + d3_, gates + d3_); + act_gate_d_->ComputeDeprecated(gates + d3_, gates + d3_); /* H_t = act_cell(C_t) * ogated */ - act_cell_d_->Compute(ct, gates + d2_); + act_cell_d_->ComputeDeprecated(ct, gates + d2_); vmul_d_->Compute(gates + d2_, gates + d3_, ht, d_); } void ComputeC1H1(T* gates, T* ct, T* ht, const T* wp_data) const override { /* C_t = igated * cgated*/ - act_gate_d_->Compute(gates + d_, gates + d_); - act_cand_d_->Compute(gates, gates); + act_gate_d_->ComputeDeprecated(gates + d_, gates + d_); + act_cand_d_->ComputeDeprecated(gates, gates); vmul_d_->Compute(gates, gates + d_, ct, d_); /* get outgated, put W_oc * C_t on igated */ vmul_d_->Compute(wp_data + d2_, ct, gates + d_, d_); vadd_d_->Compute(gates + d_, gates + d3_, gates + d3_, d_); /* H_t = act_cell(C_t) * ogated */ - act_gate_d_->Compute(gates + d3_, gates + d3_); - act_cell_d_->Compute(ct, gates + d2_); + act_gate_d_->ComputeDeprecated(gates + d3_, gates + d3_); + act_cell_d_->ComputeDeprecated(ct, gates + d2_); vmul_d_->Compute(gates + d2_, gates + d3_, ht, d_); } @@ -376,20 +376,20 @@ class GRUKernelImpl : public GRUKernel { } void ComputeH1(T* gates, T* ht) const override { - act_gate_d_->Compute(gates, gates); - act_state_d_->Compute(gates + d2_, gates + d2_); + act_gate_d_->ComputeDeprecated(gates, gates); + act_state_d_->ComputeDeprecated(gates + d2_, gates + d2_); vmul_d_->Compute(gates, gates + d2_, ht, d_); } void ComputeHtPart1(T* gates, const T* ht_1, T* ht) const override { // W: {W_update, W_reset; W_state} - act_gate_d2_->Compute(gates, gates); + act_gate_d2_->ComputeDeprecated(gates, gates); vmul_d_->Compute(ht_1, gates + d_, ht, d_); } void ComputeHtPart2(T* gates, const T* ht_1, T* ht) const override { T* y = gates + d2_; - act_state_d_->Compute(y, y); + act_state_d_->ComputeDeprecated(y, y); // out = zt*ht~ + (1-zt)*ht_1 for (int i = 0; i < d_; ++i) { ht[i] = gates[i] * y[i] + (static_cast(1) - gates[i]) * ht_1[i]; diff --git a/paddle/fluid/operators/math/jit_kernel_test.cc b/paddle/fluid/operators/math/jit_kernel_test.cc index 7dc3e600b564d95b46070ff4436b2d0de2f3e105..5e1f91ffae03796be2817d0461900c2512938c77 100644 --- a/paddle/fluid/operators/math/jit_kernel_test.cc +++ b/paddle/fluid/operators/math/jit_kernel_test.cc @@ -92,7 +92,7 @@ TEST(JitKernel, vrelu) { #endif auto ttgts = GetCurrentUS(); for (int i = 0; i < repeat; ++i) { - ker->Compute(x_data, ztgt_data); + ker->Compute(x_data, ztgt_data, d); } auto ttgte = GetCurrentUS(); VLOG(30) << "Vec size " << d @@ -181,7 +181,7 @@ TEST(JitKernel, vexp) { auto ttgts = GetCurrentUS(); for (int i = 0; i < repeat; ++i) { - ker->Compute(x_data, ztgt_data); + ker->ComputeDeprecated(x_data, ztgt_data); } auto ttgte = GetCurrentUS(); @@ -222,7 +222,7 @@ void vsigmoid_better( y[i] = (x[i] < min) ? min : ((x[i] > max) ? 
max : x[i]); y[i] = 0.f - y[i]; } - vexp->Compute(y, y); + vexp->ComputeDeprecated(y, y); for (int i = 0; i < n; ++i) { y[i] = 1.f / (1.f + y[i]); } @@ -253,7 +253,7 @@ TEST(JitKernel, vsigmoid) { auto trefe = GetCurrentUS(); auto ttgts = GetCurrentUS(); for (int i = 0; i < repeat; ++i) { - ker->Compute(x_data, ztgt_data); + ker->ComputeDeprecated(x_data, ztgt_data); } auto ttgte = GetCurrentUS(); @@ -287,7 +287,7 @@ void vtanh_better( const int n, const float* x, float* y) { const float a = 2.f, b = -1.f; vscal->Compute(&a, x, y, n); - vsigmoid->Compute(y, y); + vsigmoid->ComputeDeprecated(y, y); vscal->Compute(&a, y, y, n); vaddbias->Compute(&b, y, y, n); } @@ -321,7 +321,7 @@ TEST(JitKernel, vtanh) { auto trefe = GetCurrentUS(); auto ttgts = GetCurrentUS(); for (int i = 0; i < repeat; ++i) { - ker->Compute(x_data, ztgt_data); + ker->ComputeDeprecated(x_data, ztgt_data); } auto ttgte = GetCurrentUS(); @@ -344,8 +344,8 @@ void lstm_ctht_ref( const std::shared_ptr< const paddle::operators::math::jitkernel::VExpKernel>& vexp_1, const int d, float* gates, const float* ct_1, float* ct, float* ht) { - vsigmoid_3d->Compute(gates + d, gates + d); - vtanh_d->Compute(gates, gates); + vsigmoid_3d->ComputeDeprecated(gates + d, gates + d); + vtanh_d->ComputeDeprecated(gates, gates); const float *i = gates + d, *f = gates + d * 2, *o = gates + d * 3; const float min = SIGMOID_THRESHOLD_MIN; const float max = SIGMOID_THRESHOLD_MAX; @@ -355,7 +355,7 @@ void lstm_ctht_ref( // H_t = act_cell(C_t) * ogated float tmp = ct[k] * 2; tmp = 0.f - ((tmp < min) ? min : ((tmp > max) ? max : tmp)); - vexp_1->Compute(&tmp, &tmp); + vexp_1->ComputeDeprecated(&tmp, &tmp); tmp = 2.f / (1.f + tmp) - 1.f; ht[k] = tmp * o[k]; } @@ -373,13 +373,13 @@ void lstm_ctht_better( const paddle::operators::math::jitkernel::VAddKernel>& vadd_d, const int d, float* gates, const float* ct_1, float* ct, float* ht) { int d2 = d * 2; - vsigmoid_3d->Compute(gates + d, gates + d); - vtanh_d->Compute(gates, gates); + vsigmoid_3d->ComputeDeprecated(gates + d, gates + d); + vtanh_d->ComputeDeprecated(gates, gates); vmul_d->Compute(gates, gates + d, gates + d, d); vmul_d->Compute(ct_1, gates + d2, gates + d2, d); vadd_d->Compute(gates + d, gates + d2, ct, d); /* H_t = act_cell(C_t) * ogated */ - vtanh_d->Compute(ct, gates + d2); + vtanh_d->ComputeDeprecated(ct, gates + d2); vmul_d->Compute(gates + d2, gates + d * 3, ht, d); } @@ -736,7 +736,7 @@ void vaddrelu_better( const paddle::operators::math::jitkernel::VReluKernel>& vrelu, const float* x, const float* y, float* z, int d) { vadd->Compute(x, y, z, d); - vrelu->Compute(z, z); + vrelu->ComputeDeprecated(z, z); } TEST(JitKernel, vaddrelu) { diff --git a/paddle/fluid/operators/math/selected_rows_functor.cc b/paddle/fluid/operators/math/selected_rows_functor.cc index 9577a4cb9d275df9604b7578f8685e4d2938a5e9..5978c1d6056001142854583840b8bfcb54d475d1 100644 --- a/paddle/fluid/operators/math/selected_rows_functor.cc +++ b/paddle/fluid/operators/math/selected_rows_functor.cc @@ -244,7 +244,7 @@ typename std::enable_if< std::is_same::value>::type elementwise_add_to(const DeviceContext& ctx, BlasT* blas, size_t data_len, const T* in, T* out) { - for (int64_t i = 0; i < data_len; i++) { + for (size_t i = 0; i < data_len; i++) { out[i] += in[i]; } } diff --git a/paddle/fluid/operators/math/sequence_pooling_test.cc b/paddle/fluid/operators/math/sequence_pooling_test.cc index 2bc008dd34ffcfe93a00bd4a8cde61626d91e235..5535523e798912ff80eeb5d753914c7d8d70a05f 100644 --- 
a/paddle/fluid/operators/math/sequence_pooling_test.cc +++ b/paddle/fluid/operators/math/sequence_pooling_test.cc @@ -70,11 +70,11 @@ void TestSequencePoolingSum(const paddle::framework::LoD& lod) { EXPECT_EQ(in_grad.lod(), lod); if (paddle::platform::is_cpu_place(*place)) { - for (int64_t i = 0; i < in_grad.lod()[0].size() - 1; ++i) { + for (size_t i = 0; i < in_grad.lod()[0].size() - 1; ++i) { int64_t begin = in_grad.lod()[0][i]; int64_t end = in_grad.lod()[0][i + 1]; paddle::framework::Tensor tmp = in_grad.Slice(begin, end); - for (int64_t j = 0; j != tmp.numel() / second_dim; ++j) { + for (size_t j = 0; j != tmp.numel() / second_dim; ++j) { for (int64_t m = 0; m != second_dim; ++m) { EXPECT_EQ(tmp.data()[m + j * second_dim], out_grad.data()[m + i * second_dim]); @@ -82,11 +82,11 @@ void TestSequencePoolingSum(const paddle::framework::LoD& lod) { } } } else { - for (int64_t i = 0; i < cpu_in_grad.lod()[0].size() - 1; ++i) { + for (size_t i = 0; i < cpu_in_grad.lod()[0].size() - 1; ++i) { int64_t begin = cpu_in_grad.lod()[0][i]; int64_t end = cpu_in_grad.lod()[0][i + 1]; paddle::framework::Tensor tmp = cpu_in_grad.Slice(begin, end); - for (int64_t j = 0; j != tmp.numel() / second_dim; ++j) { + for (size_t j = 0; j != tmp.numel() / second_dim; ++j) { for (int64_t m = 0; m != second_dim; ++m) { EXPECT_EQ(tmp.data()[m + j * second_dim], cpu_out_grad.data()[m + i * second_dim]); diff --git a/paddle/fluid/operators/math/softmax.cc b/paddle/fluid/operators/math/softmax.cc index 78c65af24a8c5fa57e33415acc3018790bf70790..fa2018178f44ff4e3b14937c1f508fa8a698e20e 100644 --- a/paddle/fluid/operators/math/softmax.cc +++ b/paddle/fluid/operators/math/softmax.cc @@ -19,8 +19,10 @@ namespace paddle { namespace operators { namespace math { -template class SoftmaxFunctor; -template class SoftmaxFunctor; +template class SoftmaxFunctor; +template class SoftmaxFunctor; +template class SoftmaxFunctor; +template class SoftmaxFunctor; template class SoftmaxGradFunctor; template class SoftmaxGradFunctor; diff --git a/paddle/fluid/operators/math/softmax.cu b/paddle/fluid/operators/math/softmax.cu index ce183ed3649055aab31eb6e3f44f2224475957e9..2e9669049e36478549b793e3fa76220825888e21 100644 --- a/paddle/fluid/operators/math/softmax.cu +++ b/paddle/fluid/operators/math/softmax.cu @@ -98,9 +98,14 @@ template class SoftmaxGradCUDNNFunctor; template class SoftmaxGradCUDNNFunctor; template class SoftmaxGradCUDNNFunctor; -template class SoftmaxFunctor; -template class SoftmaxFunctor; -template class SoftmaxFunctor; +template class SoftmaxFunctor; +template class SoftmaxFunctor; +template class SoftmaxFunctor; +template class SoftmaxFunctor; +template class SoftmaxFunctor; +template class SoftmaxFunctor; template class SoftmaxGradFunctor; template class SoftmaxGradFunctor; template class SoftmaxGradFunctor +template class SoftmaxFunctor { public: void operator()(const DeviceContext& context, const framework::Tensor* X, diff --git a/paddle/fluid/operators/math/softmax_impl.h b/paddle/fluid/operators/math/softmax_impl.h index dd9971ba091cc3ece86654f65c335b98087f45ed..7cf98f27251db3cfe5e8e295ed21056f6e5a2963 100644 --- a/paddle/fluid/operators/math/softmax_impl.h +++ b/paddle/fluid/operators/math/softmax_impl.h @@ -32,10 +32,10 @@ struct ValueClip { } }; -template -void SoftmaxFunctor::operator()(const DeviceContext& context, - const framework::Tensor* X, - framework::Tensor* Y) { +template +void SoftmaxFunctor::operator()( + const DeviceContext& context, const framework::Tensor* X, + framework::Tensor* Y) { auto logits = 
EigenMatrix<T>::From(*X);
   auto softmax = EigenMatrix<T>::From(*Y);
@@ -65,6 +65,39 @@ void SoftmaxFunctor<DeviceContext, T>::operator()(const DeviceContext& context,
                        .broadcast(one_by_class));
 }
 
+template <typename DeviceContext, typename T>
+class SoftmaxFunctor<DeviceContext, T, true> {
+ public:
+  void operator()(const DeviceContext& context, const framework::Tensor* X,
+                  framework::Tensor* Y) {
+    auto logits = EigenMatrix<T>::From(*X);
+    auto softmax = EigenMatrix<T>::From(*Y);
+
+    const int kBatchDim = 0;
+    const int kClassDim = 1;
+
+    const int batch_size = logits.dimension(kBatchDim);
+    const int num_classes = logits.dimension(kClassDim);
+
+    Eigen::DSizes<int, 1> along_class(kClassDim);
+    Eigen::DSizes<int, 2> batch_by_one(batch_size, 1);
+    Eigen::DSizes<int, 2> one_by_class(1, num_classes);
+
+    auto shifted_logits = (logits -
+                           logits.maximum(along_class)
+                               .eval()
+                               .reshape(batch_by_one)
+                               .broadcast(one_by_class));
+
+    softmax.device(*context.eigen_device()) = shifted_logits.exp();
+    softmax.device(*context.eigen_device()) = (softmax *
+                                               softmax.sum(along_class)
+                                                   .inverse()
+                                                   .eval()
+                                                   .reshape(batch_by_one)
+                                                   .broadcast(one_by_class));
+  }
+};
+
 template <typename DeviceContext, typename T>
 void SoftmaxGradFunctor<DeviceContext, T>::operator()(
     const DeviceContext& context, const framework::Tensor* y,
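The specialization above is the inference-only (is_test = true) softmax: it drops the ValueClip step that the training path applies before exponentiation and computes a plain shifted softmax, exp(x - max) normalized by its row sum. A standalone numeric sketch of that path, without Eigen, for a single row (illustrative code, not from the patch):

#include <algorithm>
#include <cassert>
#include <cmath>
#include <vector>

std::vector<float> SoftmaxRow(const std::vector<float> &logits) {
  float mx = logits[0];
  for (float v : logits) mx = std::max(mx, v);  // shift by the row maximum
  std::vector<float> out(logits.size());
  float sum = 0.f;
  for (size_t i = 0; i < logits.size(); ++i) {
    out[i] = std::exp(logits[i] - mx);          // no ValueClip on this path
    sum += out[i];
  }
  for (float &v : out) v /= sum;                // softmax = exp / sum(exp)
  return out;
}

int main() {
  auto p = SoftmaxRow({1.f, 2.f, 3.f});
  float s = p[0] + p[1] + p[2];
  assert(std::fabs(s - 1.f) < 1e-6f);  // probabilities sum to one
}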
diff --git a/paddle/fluid/operators/merge_ids_op.h b/paddle/fluid/operators/merge_ids_op.h
index fef9e023d02f45e21ec409ad398ba7d9bdd36880..99c57590191d58a12760fb335df76037685d1ced 100644
--- a/paddle/fluid/operators/merge_ids_op.h
+++ b/paddle/fluid/operators/merge_ids_op.h
@@ -43,11 +43,11 @@ class MergeIdsOpKernel : public framework::OpKernel<T> {
     PADDLE_ENFORCE_EQ(ids.size(), outs.size(),
                       "the number of Ids and Out should be the same");
 
-    int row_ids_size = 0;
+    size_t row_ids_size = 0;
     int row_size = 0;
     int embedding_size = 0;
 
-    for (int i = 0; i < x_tensors.size(); ++i) {
+    for (size_t i = 0; i < x_tensors.size(); ++i) {
       const auto *x_tensor = x_tensors[i];
       const auto *row_id = row_ids[i];
 
@@ -66,7 +66,7 @@ class MergeIdsOpKernel : public framework::OpKernel<T> {
     std::unordered_map> selected_rows_idx_map;
 
-    for (int i = 0; i < x_tensors.size(); ++i) {
+    for (size_t i = 0; i < x_tensors.size(); ++i) {
       const auto *row_id = row_ids[i];
 
       for (int j = 0; j < row_id->numel(); ++j) {
@@ -78,7 +78,7 @@ class MergeIdsOpKernel : public framework::OpKernel<T> {
     PADDLE_ENFORCE_EQ(row_ids_size, selected_rows_idx_map.size(),
                       "the rows and tensor map size should be the same");
 
-    for (int i = 0; i < outs.size(); ++i) {
+    for (size_t i = 0; i < outs.size(); ++i) {
       auto *out_ids = ids[i];
       auto *out = outs[i];
diff --git a/paddle/fluid/operators/ref_by_trainer_id_op.h b/paddle/fluid/operators/ref_by_trainer_id_op.h
index 2ce577544ae2437b9297da2190fd09b435d5173c..34192278d84758d720e021215c14a54349ba0c62 100644
--- a/paddle/fluid/operators/ref_by_trainer_id_op.h
+++ b/paddle/fluid/operators/ref_by_trainer_id_op.h
@@ -38,7 +38,7 @@ class RefByTrainerIdKernel : public framework::OpKernel<T> {
     } else {
       trainer_id = *trainer_id_data;
     }
-    PADDLE_ENFORCE_LT(trainer_id, in_list.size());
+    PADDLE_ENFORCE_LT((size_t)trainer_id, in_list.size());
     out->mutable_data<T>(context.GetPlace());
     out->ShareDataWith(*(in_list[trainer_id]));
   }
diff --git a/paddle/fluid/operators/softmax_op.h b/paddle/fluid/operators/softmax_op.h
index cf1eeb017d666f605a431aa54637d8cbc99c7c46..2fea8a65bc5141b11549ef400f11b54278be35f9 100644
--- a/paddle/fluid/operators/softmax_op.h
+++ b/paddle/fluid/operators/softmax_op.h
@@ -35,8 +35,13 @@ class SoftmaxKernel : public framework::OpKernel<T> {
     Tensor X_2d = framework::ReshapeToMatrix(*X, rank - 1);
     Tensor Out_2d = framework::ReshapeToMatrix(*Out, rank - 1);
 
-    math::SoftmaxFunctor<DeviceContext, T>()(
+#ifdef ON_INFER
+    math::SoftmaxFunctor<DeviceContext, T, true>()(
         context.template device_context<DeviceContext>(), &X_2d, &Out_2d);
+#else
+    math::SoftmaxFunctor<DeviceContext, T, false>()(
+        context.template device_context<DeviceContext>(), &X_2d, &Out_2d);
+#endif
   }
 };
diff --git a/paddle/fluid/operators/softmax_with_cross_entropy_op.h b/paddle/fluid/operators/softmax_with_cross_entropy_op.h
index e9aba3b37b8cc01d4fe5de5200579d4e93f67e56..c0530e3d8bc407ddd6d7bf6e10a715185d0beb1f 100644
--- a/paddle/fluid/operators/softmax_with_cross_entropy_op.h
+++ b/paddle/fluid/operators/softmax_with_cross_entropy_op.h
@@ -42,8 +42,8 @@ class SoftmaxWithCrossEntropyKernel : public framework::OpKernel<T> {
 
     auto& dev_ctx =
         context.template device_context<DeviceContext>();
-    math::SoftmaxFunctor<DeviceContext, T>()(dev_ctx, logits,
-                                             softmax);
+    math::SoftmaxFunctor<DeviceContext, T, false>()(
+        dev_ctx, logits, softmax);
     math::CrossEntropyFunctor<DeviceContext, T>()(
         dev_ctx, loss, softmax, labels, context.Attr<bool>("soft_label"),
         context.Attr<int>("ignore_index"));
diff --git a/paddle/fluid/operators/split_ids_op.h b/paddle/fluid/operators/split_ids_op.h
index 6dbada3da8826f0e7cb07a9642d327e5ee38c309..f5d6d85d7d75507f82de212812ecee0a650d3aad 100644
--- a/paddle/fluid/operators/split_ids_op.h
+++ b/paddle/fluid/operators/split_ids_op.h
@@ -64,7 +64,7 @@ class SplitIdsOpKernel : public framework::OpKernel<T> {
       out_ids.resize(outs.size());
 
       // split id by their shard_num.
-      for (int i = 0; i < all_ids.size(); ++i) {
+      for (size_t i = 0; i < all_ids.size(); ++i) {
         T id = all_ids[i];
         size_t shard_id = static_cast<size_t>(id) % shard_num;
         out_ids[shard_id].push_back(id);
diff --git a/python/paddle/fluid/__init__.py b/python/paddle/fluid/__init__.py
index 9a375d37e66332a55b00516e8476b0fe446402a2..81299189160739e54f39348ad327ff2edd2ac0e0 100644
--- a/python/paddle/fluid/__init__.py
+++ b/python/paddle/fluid/__init__.py
@@ -112,11 +112,11 @@ def __bootstrap__():
         os.environ['OMP_NUM_THREADS'] = str(num_threads)
 
     read_env_flags = [
-        'use_pinned_memory', 'check_nan_inf', 'benchmark',
-        'eager_delete_scope', 'use_mkldnn', 'use_ngraph',
-        'initial_cpu_memory_in_mb', 'init_allocated_mem', 'free_idle_memory',
-        'paddle_num_threads', 'dist_threadpool_size',
-        'eager_delete_tensor_gb', 'reader_queue_speed_test_mode'
+        'use_pinned_memory', 'check_nan_inf', 'benchmark', 'eager_delete_scope',
+        'use_mkldnn', 'use_ngraph', 'initial_cpu_memory_in_mb',
+        'init_allocated_mem', 'free_idle_memory', 'paddle_num_threads',
+        'dist_threadpool_size', 'eager_delete_tensor_gb',
+        'reader_queue_speed_test_mode'
     ]
     if os.name != 'nt':
         read_env_flags.append('warpctc_dir')
diff --git a/python/paddle/fluid/tests/unittests/dist_save_load.py b/python/paddle/fluid/tests/unittests/dist_save_load.py
index edc60550058f53da456c21de4b41142b907743df..cf62817956c12cd4487eba88bf49ed43331dff03 100644
--- a/python/paddle/fluid/tests/unittests/dist_save_load.py
+++ b/python/paddle/fluid/tests/unittests/dist_save_load.py
@@ -26,6 +26,7 @@ from multiprocessing import Process
 from functools import reduce
 
 import numpy as np
+import pickle
 import unittest
 import six
 
@@ -166,7 +167,10 @@ class TestDistSaveLoad2x2(TestDistSimnetBow2x2):
             io.save_persistables(startup_exe, model_dir, trainer_prog)
 
         var = np.array(fluid.global_scope().find_var('__fc_b__').get_tensor())
-        print(np.ravel(var).tolist())
+        if six.PY2:
+            print(pickle.dumps(np.ravel(var).tolist()))
+        else:
+            sys.stdout.buffer.write(pickle.dumps(np.ravel(var).tolist()))
 
 
 if __name__ == "__main__":
diff --git a/python/paddle/fluid/tests/unittests/test_dist_save_load.py b/python/paddle/fluid/tests/unittests/test_dist_save_load.py
index
03066fee48b703f8b55bd4ae6a9c4bb8deecab1e..ea2b554dac83988955e3a7e8919e57a4ed7a8215 100644 --- a/python/paddle/fluid/tests/unittests/test_dist_save_load.py +++ b/python/paddle/fluid/tests/unittests/test_dist_save_load.py @@ -65,14 +65,14 @@ class TestDistSaveLoadDense2x2(TestDistBase): shutil.rmtree(model_dir) - local_np = np.array(eval(local_var[0])) - train0_np = np.array(eval(tr0_var[0])) - train1_np = np.array(eval(tr1_var[0])) + local_np = np.array(local_var) + train0_np = np.array(tr0_var) + train1_np = np.array(tr1_var) + self.assertAlmostEqual(local_np.all(), train0_np.all(), delta=delta) self.assertAlmostEqual(local_np.all(), train1_np.all(), delta=delta) self.assertAlmostEqual(train0_np.all(), train1_np.all(), delta=delta) - @unittest.skip(reason="CI fail") def test_dist(self): need_envs = { "IS_DISTRIBUTED": '0', diff --git a/python/requirements.txt b/python/requirements.txt index 7a24dd519afb0d0bc84e89da9f6b2ab2fa8718ce..84cf440397b994ba12fa70d9e316e788f34e2415 100644 --- a/python/requirements.txt +++ b/python/requirements.txt @@ -1,7 +1,7 @@ requests==2.9.2 numpy>=1.12,<=1.14 #TODO:change to ">=1.12" when numpy fix bug in 1.15 and higher version protobuf==3.1 -recordio>=0.1.0; sys_platform != 'win32' +recordio>=0.1.0 matplotlib==2.2.3 # TODO: let python3 paddlepaddle package use latest matplotlib rarfile scipy>=0.19.0
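A recurring theme in this patch (merge_ids_op.h, split_ids_op.h, ref_by_trainer_id_op.h, the sequence-pooling and selected-rows code) is switching loop indices to size_t, or inserting casts, so that comparisons against container .size() results are between like-signed types. A small illustration of the -Wsign-compare warning class these hunks silence; this is illustrative code, not from the patch:

#include <vector>

int CountAll(const std::vector<int> &v) {
  int n = 0;
  // for (int i = 0; i < v.size(); ++i)   // warns: signed int vs size_t
  for (size_t i = 0; i < v.size(); ++i)   // same-signed comparison, no warning
    n += v[i];
  return n;
}

int main() { return CountAll({1, 2, 3}) == 6 ? 0 : 1; }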