diff --git a/CMakeLists.txt b/CMakeLists.txt index c649aafeddaf9f28c213d086236c3779d3137d92..de47086dbd6a440cd413c7843c83b1c69d9841b2 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -39,6 +39,7 @@ option(WITH_GPU "Compile PaddlePaddle with NVIDIA GPU" ${CUDA_F option(WITH_AMD_GPU "Compile PaddlePaddle with AMD GPU" OFF) option(WITH_AVX "Compile PaddlePaddle with AVX intrinsics" ${AVX_FOUND}) option(WITH_MKL "Compile PaddlePaddle with MKL support." ${AVX_FOUND}) +option(WITH_TENSORRT "Compile PaddlePaddle with TensorRT support." OFF) option(WITH_DSO "Compile PaddlePaddle with dynamic linked CUDA" ON) option(WITH_TESTING "Compile PaddlePaddle with unit testing" OFF) option(WITH_SWIG_PY "Compile PaddlePaddle with inference api" ON) @@ -181,6 +182,11 @@ if(WITH_GPU) include(cuda) endif(WITH_GPU) +# TensorRT depends on GPU. +if (NOT WITH_GPU) + set(WITH_TENSORRT OFF) +endif() + if(WITH_AMD_GPU) find_package(HIP) include(hip) diff --git a/Dockerfile b/Dockerfile index 7856d3bbc492af4cad2d6b9f49001c90eadbea43..9097bb657d2366997112ec7662762a93358aa647 100644 --- a/Dockerfile +++ b/Dockerfile @@ -45,6 +45,13 @@ ENV PATH=${PATH}:${GOROOT}/bin:${GOPATH}/bin # install glide RUN curl -s -q https://glide.sh/get | sh +# Install TensorRT +# The unnecessary files has been removed to make the library small. +RUN wget -qO- http://paddlepaddledeps.bj.bcebos.com/TensorRT-4.0.0.3.Ubuntu-16.04.4.x86_64-gnu.cuda-8.0.cudnn7.0.tar.gz | \ + tar -xz -C /usr/local && \ + cp -rf /usr/local/TensorRT/include /usr && \ + cp -rf /usr/local/TensorRT/lib /usr + # git credential to skip password typing RUN git config --global credential.helper store @@ -57,7 +64,7 @@ RUN localedef -i en_US -f UTF-8 en_US.UTF-8 # specify sphinx version as 1.5.6 and remove -U option for [pip install -U # sphinx-rtd-theme] since -U option will cause sphinx being updated to newest # version(1.7.1 for now), which causes building documentation failed. -RUN pip install --upgrade pip && \ +RUN pip install --upgrade pip==9.0.3 && \ pip install -U wheel && \ pip install -U docopt PyYAML sphinx==1.5.6 && \ pip install sphinx-rtd-theme==0.1.9 recommonmark diff --git a/cmake/external/grpc.cmake b/cmake/external/grpc.cmake index aa249159470773241e0f6da2e8e086264634dd4a..e90948782bb5e333bbdb47ef9d61c1e37e3cf9e4 100644 --- a/cmake/external/grpc.cmake +++ b/cmake/external/grpc.cmake @@ -33,7 +33,7 @@ ExternalProject_Add( extern_grpc DEPENDS protobuf zlib GIT_REPOSITORY "https://github.com/grpc/grpc.git" - GIT_TAG "v1.11.x" + GIT_TAG "v1.10.x" PREFIX ${GRPC_SOURCES_DIR} UPDATE_COMMAND "" CONFIGURE_COMMAND "" diff --git a/doc/fluid/api/layers.rst b/doc/fluid/api/layers.rst index 22e6fb13d7320986a60bc1ef5530187e0970c767..5c02886efd7d11e9520910526fb90ec01e123bae 100644 --- a/doc/fluid/api/layers.rst +++ b/doc/fluid/api/layers.rst @@ -473,6 +473,12 @@ multiplex .. autofunction:: paddle.fluid.layers.multiplex :noindex: +label_smooth +------------ + +.. autofunction:: paddle.fluid.layers.label_smooth + :noindex: + ops === diff --git a/doc/fluid/dev/index_cn.rst b/doc/fluid/dev/index_cn.rst index b123b756e2251c38f319e1aefa2cb04fd7a36b03..ad798003f560e7fb0e6db6083fdd152fd3417584 100644 --- a/doc/fluid/dev/index_cn.rst +++ b/doc/fluid/dev/index_cn.rst @@ -4,6 +4,7 @@ .. toctree:: :maxdepth: 1 + api_doc_std_cn.md new_op_cn.md new_op_kernel.md use_eigen_cn.md diff --git a/doc/fluid/dev/index_en.rst b/doc/fluid/dev/index_en.rst index 98988fc22dcedecdbcd67fb3bf761377bf046337..80c899a82fa452c5cd8f38dad89c15d3041b09e3 100644 --- a/doc/fluid/dev/index_en.rst +++ b/doc/fluid/dev/index_en.rst @@ -4,6 +4,7 @@ Development .. toctree:: :maxdepth: 1 + api_doc_std_en.md new_op_en.md new_op_kernel.md use_eigen_en.md diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.cc b/paddle/fluid/framework/details/multi_devices_graph_builder.cc index e0dd9e6068174a4b0348d503f4082bee6ff68dac..5a95cbc53625888bac539f91af391ff0babec17b 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_builder.cc +++ b/paddle/fluid/framework/details/multi_devices_graph_builder.cc @@ -55,21 +55,21 @@ MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder( } } -void MultiDevSSAGraphBuilder::CreateOpHandleIOs(SSAGraph *result, OpDesc *op, +void MultiDevSSAGraphBuilder::CreateOpHandleIOs(SSAGraph *result, + const OpDesc &op, const platform::Place &p, const size_t &i) const { auto *op_handle = result->ops_.back().get(); - op_handle->dev_ctxes_[p] = const_cast( - platform::DeviceContextPool::Instance().Get(p)); + op_handle->dev_ctxes_[p] = platform::DeviceContextPool::Instance().Get(p); - auto var_names = op->InputArgumentNames(); + auto var_names = op.InputArgumentNames(); for (auto &each_var_name : var_names) { VarHandle *var = CreateOrGetLatestVarHandle(result, each_var_name, p, i); op_handle->AddInput(var); } - var_names = op->OutputArgumentNames(); + var_names = op.OutputArgumentNames(); for (auto &each_var_name : var_names) { CreateOpOutput(result, op_handle, each_var_name, p, i); @@ -107,7 +107,7 @@ std::unique_ptr MultiDevSSAGraphBuilder::Build( result.ops_.emplace_back(new SendOpHandle(*op, s, p)); // Create inputs for output on original place and no ssa output // is created for send op. - CreateOpHandleIOs(&result, op, p, 0); + CreateOpHandleIOs(&result, *op, p, 0); continue; } @@ -117,7 +117,7 @@ std::unique_ptr MultiDevSSAGraphBuilder::Build( result.ops_.emplace_back(new ComputationOpHandle(*op, s, p)); auto *op_handle = result.ops_.back().get(); - CreateOpHandleIOs(&result, op, p, i); + CreateOpHandleIOs(&result, *op, p, i); auto var_names = op->OutputArgumentNames(); diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.h b/paddle/fluid/framework/details/multi_devices_graph_builder.h index de34caab1be85eecb741a5003f026eb982e178ea..f1518d75b421006db6311c3b0f602e47000ab381 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_builder.h +++ b/paddle/fluid/framework/details/multi_devices_graph_builder.h @@ -45,8 +45,8 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder { std::unique_ptr Build(const ProgramDesc &program) const override; private: - void CreateOpHandleIOs(SSAGraph *result, OpDesc *op, const platform::Place &p, - const size_t &i) const; + void CreateOpHandleIOs(SSAGraph *result, const OpDesc &op, + const platform::Place &p, const size_t &i) const; private: std::string loss_var_name_; diff --git a/paddle/fluid/framework/program_desc_test.cc b/paddle/fluid/framework/program_desc_test.cc index 66618a291b59996836e822587af618927a4263c7..6c46e9aad5b7fbf67fdcc07a12e7932ac8b6412b 100644 --- a/paddle/fluid/framework/program_desc_test.cc +++ b/paddle/fluid/framework/program_desc_test.cc @@ -66,7 +66,7 @@ TEST(ProgramDesc, copy_ctor) { for (size_t i = 0; i < global_block->OpSize(); ++i) { auto op_origin = global_block->Op(i); - auto op_copy = global_block->Op(i); + auto op_copy = global_block_copy->Op(i); ASSERT_EQ(op_origin->Type(), op_copy->Type()); ASSERT_EQ(op_origin->Inputs(), op_copy->Inputs()); @@ -131,7 +131,7 @@ TEST(ProgramDescBind, serialize_and_deserialize) { for (size_t i = 0; i < global_block->OpSize(); ++i) { auto op_origin = global_block->Op(i); - auto op_restored = global_block->Op(i); + auto op_restored = global_block_restored->Op(i); ASSERT_EQ(op_origin->Type(), op_restored->Type()); ASSERT_EQ(op_origin->Inputs(), op_restored->Inputs()); diff --git a/paddle/fluid/inference/CMakeLists.txt b/paddle/fluid/inference/CMakeLists.txt index e53bcf2384e54e21c7dd5638f3b7469a35b571bf..8494edee6c2c714c285c45bbb4fe1d8cb1a524aa 100644 --- a/paddle/fluid/inference/CMakeLists.txt +++ b/paddle/fluid/inference/CMakeLists.txt @@ -21,4 +21,7 @@ endif() if(WITH_TESTING) add_subdirectory(tests/book) + if (WITH_TENSORRT) + add_subdirectory(tensorrt) + endif() endif() diff --git a/paddle/fluid/inference/tensorrt/CMakeLists.txt b/paddle/fluid/inference/tensorrt/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..e39c0daac76e0993382868289f66351da3d16f8f --- /dev/null +++ b/paddle/fluid/inference/tensorrt/CMakeLists.txt @@ -0,0 +1 @@ +nv_test(test_tensorrt SRCS test_tensorrt.cc DEPS dynload_cuda device_context dynamic_loader) diff --git a/paddle/fluid/inference/tensorrt/test_tensorrt.cc b/paddle/fluid/inference/tensorrt/test_tensorrt.cc new file mode 100644 index 0000000000000000000000000000000000000000..a81a708e7a79225fd52c4b8e081afdcd8fe7e9ad --- /dev/null +++ b/paddle/fluid/inference/tensorrt/test_tensorrt.cc @@ -0,0 +1,155 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include +#include +#include "NvInfer.h" +#include "cuda.h" +#include "cuda_runtime_api.h" +#include "paddle/fluid/platform/dynload/tensorrt.h" + +namespace dy = paddle::platform::dynload; + +class Logger : public nvinfer1::ILogger { + public: + void log(nvinfer1::ILogger::Severity severity, const char* msg) override { + switch (severity) { + case Severity::kINFO: + LOG(INFO) << msg; + break; + case Severity::kWARNING: + LOG(WARNING) << msg; + break; + case Severity::kINTERNAL_ERROR: + case Severity::kERROR: + LOG(ERROR) << msg; + break; + default: + break; + } + } +}; + +class ScopedWeights { + public: + ScopedWeights(float value) : value_(value) { + w.type = nvinfer1::DataType::kFLOAT; + w.values = &value_; + w.count = 1; + } + const nvinfer1::Weights& get() { return w; } + + private: + float value_; + nvinfer1::Weights w; +}; + +// The following two API are implemented in TensorRT's header file, cannot load +// from the dynamic library. So create our own implementation and directly +// trigger the method from the dynamic library. +nvinfer1::IBuilder* createInferBuilder(nvinfer1::ILogger& logger) { + return static_cast( + dy::createInferBuilder_INTERNAL(&logger, NV_TENSORRT_VERSION)); +} +nvinfer1::IRuntime* createInferRuntime(nvinfer1::ILogger& logger) { + return static_cast( + dy::createInferRuntime_INTERNAL(&logger, NV_TENSORRT_VERSION)); +} + +const char* kInputTensor = "input"; +const char* kOutputTensor = "output"; + +// Creates a network to compute y = 2x + 3 +nvinfer1::IHostMemory* CreateNetwork() { + Logger logger; + // Create the engine. + nvinfer1::IBuilder* builder = createInferBuilder(logger); + ScopedWeights weights(2.); + ScopedWeights bias(3.); + + nvinfer1::INetworkDefinition* network = builder->createNetwork(); + // Add the input + auto input = network->addInput(kInputTensor, nvinfer1::DataType::kFLOAT, + nvinfer1::DimsCHW{1, 1, 1}); + EXPECT_NE(input, nullptr); + // Add the hidden layer. + auto layer = network->addFullyConnected(*input, 1, weights.get(), bias.get()); + EXPECT_NE(layer, nullptr); + // Mark the output. + auto output = layer->getOutput(0); + output->setName(kOutputTensor); + network->markOutput(*output); + // Build the engine. + builder->setMaxBatchSize(1); + builder->setMaxWorkspaceSize(1 << 10); + auto engine = builder->buildCudaEngine(*network); + EXPECT_NE(engine, nullptr); + // Serialize the engine to create a model, then close. + nvinfer1::IHostMemory* model = engine->serialize(); + network->destroy(); + engine->destroy(); + builder->destroy(); + return model; +} + +void Execute(nvinfer1::IExecutionContext& context, const float* input, + float* output) { + const nvinfer1::ICudaEngine& engine = context.getEngine(); + // Two binds, input and output + ASSERT_EQ(engine.getNbBindings(), 2); + const int input_index = engine.getBindingIndex(kInputTensor); + const int output_index = engine.getBindingIndex(kOutputTensor); + // Create GPU buffers and a stream + void* buffers[2]; + ASSERT_EQ(0, cudaMalloc(&buffers[input_index], sizeof(float))); + ASSERT_EQ(0, cudaMalloc(&buffers[output_index], sizeof(float))); + cudaStream_t stream; + ASSERT_EQ(0, cudaStreamCreate(&stream)); + // Copy the input to the GPU, execute the network, and copy the output back. + ASSERT_EQ(0, cudaMemcpyAsync(buffers[input_index], input, sizeof(float), + cudaMemcpyHostToDevice, stream)); + context.enqueue(1, buffers, stream, nullptr); + ASSERT_EQ(0, cudaMemcpyAsync(output, buffers[output_index], sizeof(float), + cudaMemcpyDeviceToHost, stream)); + cudaStreamSynchronize(stream); + + // Release the stream and the buffers + cudaStreamDestroy(stream); + ASSERT_EQ(0, cudaFree(buffers[input_index])); + ASSERT_EQ(0, cudaFree(buffers[output_index])); +} + +TEST(TensorrtTest, BasicFunction) { + // Create the network serialized model. + nvinfer1::IHostMemory* model = CreateNetwork(); + + // Use the model to create an engine and an execution context. + Logger logger; + nvinfer1::IRuntime* runtime = createInferRuntime(logger); + nvinfer1::ICudaEngine* engine = + runtime->deserializeCudaEngine(model->data(), model->size(), nullptr); + model->destroy(); + nvinfer1::IExecutionContext* context = engine->createExecutionContext(); + + // Execute the network. + float input = 1234; + float output; + Execute(*context, &input, &output); + EXPECT_EQ(output, input * 2 + 3); + + // Destroy the engine. + context->destroy(); + engine->destroy(); + runtime->destroy(); +} diff --git a/paddle/fluid/platform/dynload/CMakeLists.txt b/paddle/fluid/platform/dynload/CMakeLists.txt index 84dac2937de02b3374156ebc83e19dac9f9a3e7a..b93b925a72a55442c105e4280a3580f4ea5b93a1 100644 --- a/paddle/fluid/platform/dynload/CMakeLists.txt +++ b/paddle/fluid/platform/dynload/CMakeLists.txt @@ -1,6 +1,11 @@ cc_library(dynamic_loader SRCS dynamic_loader.cc DEPS glog gflags enforce) list(APPEND CUDA_SRCS cublas.cc cudnn.cc curand.cc nccl.cc) +if (WITH_TENSORRT) + list(APPEND CUDA_SRCS tensorrt.cc) +endif() + + configure_file(cupti_lib_path.h.in ${CMAKE_CURRENT_BINARY_DIR}/cupti_lib_path.h) if (CUPTI_FOUND) list(APPEND CUDA_SRCS cupti.cc) diff --git a/paddle/fluid/platform/dynload/dynamic_loader.cc b/paddle/fluid/platform/dynload/dynamic_loader.cc index 3c1ccc7445ed27c711ab250aa223c66ae0da45dc..19c01dc5a968c7e1d2b0f15cf9a0e8427004e58b 100644 --- a/paddle/fluid/platform/dynload/dynamic_loader.cc +++ b/paddle/fluid/platform/dynload/dynamic_loader.cc @@ -45,6 +45,10 @@ DEFINE_string(nccl_dir, "", DEFINE_string(cupti_dir, "", "Specify path for loading cupti.so."); +DEFINE_string( + tensorrt_dir, "", + "Specify path for loading tensorrt library, such as libnvinfer.so."); + namespace paddle { namespace platform { namespace dynload { @@ -194,6 +198,14 @@ void* GetNCCLDsoHandle() { #endif } +void* GetTensorRtDsoHandle() { +#if defined(__APPLE__) || defined(__OSX__) + return GetDsoHandleFromSearchPath(FLAGS_tensorrt_dir, "libnvinfer.dylib"); +#else + return GetDsoHandleFromSearchPath(FLAGS_tensorrt_dir, "libnvinfer.so"); +#endif +} + } // namespace dynload } // namespace platform } // namespace paddle diff --git a/paddle/fluid/platform/dynload/dynamic_loader.h b/paddle/fluid/platform/dynload/dynamic_loader.h index 4c85093a43e0e8d75b64c5b29d1ec68db1b44909..0de3559b6088086cb52c254535b6ec42da7dd724 100644 --- a/paddle/fluid/platform/dynload/dynamic_loader.h +++ b/paddle/fluid/platform/dynload/dynamic_loader.h @@ -25,6 +25,7 @@ void* GetCurandDsoHandle(); void* GetWarpCTCDsoHandle(); void* GetLapackDsoHandle(); void* GetNCCLDsoHandle(); +void* GetTensorRtDsoHandle(); } // namespace dynload } // namespace platform diff --git a/paddle/fluid/platform/dynload/tensorrt.cc b/paddle/fluid/platform/dynload/tensorrt.cc new file mode 100644 index 0000000000000000000000000000000000000000..f3c8e27944ca9b6419de87d752df3a83751039b1 --- /dev/null +++ b/paddle/fluid/platform/dynload/tensorrt.cc @@ -0,0 +1,30 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "paddle/fluid/platform/dynload/tensorrt.h" + +namespace paddle { +namespace platform { +namespace dynload { + +std::once_flag tensorrt_dso_flag; +void *tensorrt_dso_handle; + +#define DEFINE_WRAP(__name) DynLoad__##__name __name + +TENSORRT_RAND_ROUTINE_EACH(DEFINE_WRAP); + +} // namespace dynload +} // namespace platform +} // namespace paddle diff --git a/paddle/fluid/platform/dynload/tensorrt.h b/paddle/fluid/platform/dynload/tensorrt.h new file mode 100644 index 0000000000000000000000000000000000000000..f584a49da0fefe0b064b5fb55b01ec132225ce5e --- /dev/null +++ b/paddle/fluid/platform/dynload/tensorrt.h @@ -0,0 +1,69 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ +#pragma once + +#include +#include + +#include // NOLINT + +#include "paddle/fluid/platform/dynload/dynamic_loader.h" +#include "paddle/fluid/platform/enforce.h" + +namespace paddle { +namespace platform { +namespace dynload { + +extern std::once_flag tensorrt_dso_flag; +extern void* tensorrt_dso_handle; + +#ifdef PADDLE_USE_DSO + +#define DECLARE_DYNAMIC_LOAD_TENSORRT_WRAP(__name) \ + struct DynLoad__##__name { \ + template \ + auto operator()(Args... args) -> decltype(__name(args...)) { \ + using tensorrt_func = decltype(__name(args...)) (*)(Args...); \ + std::call_once(tensorrt_dso_flag, []() { \ + tensorrt_dso_handle = \ + paddle::platform::dynload::GetTensorRtDsoHandle(); \ + PADDLE_ENFORCE(tensorrt_dso_handle, "load tensorrt so failed"); \ + }); \ + void* p_##__name = dlsym(tensorrt_dso_handle, #__name); \ + PADDLE_ENFORCE(p_##__name, "load %s failed", #__name); \ + return reinterpret_cast(p_##__name)(args...); \ + } \ + }; \ + extern DynLoad__##__name __name + +#else +#define DECLARE_DYNAMIC_LOAD_TENSORRT_WRAP(__name) \ + struct DynLoad__##__name { \ + template \ + tensorrtResult_t operator()(Args... args) { \ + return __name(args...); \ + } \ + }; \ + extern DynLoad__##__name __name +#endif + +#define TENSORRT_RAND_ROUTINE_EACH(__macro) \ + __macro(createInferBuilder_INTERNAL); \ + __macro(createInferRuntime_INTERNAL); + +TENSORRT_RAND_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_TENSORRT_WRAP) + +} // namespace dynload +} // namespace platform +} // namespace paddle diff --git a/paddle/utils/DynamicLoader.cpp b/paddle/utils/DynamicLoader.cpp index 5604a90038b06d2c1a4d9db70e4185cddfd25d3e..9ac4a56c6e300d299467630b39a32567af72cf40 100644 --- a/paddle/utils/DynamicLoader.cpp +++ b/paddle/utils/DynamicLoader.cpp @@ -32,6 +32,8 @@ DEFINE_string(warpctc_dir, "", "Specify path for loading libwarpctc.so."); DEFINE_string(lapack_dir, "", "Specify path for loading liblapack.so."); +DEFINE_string(tensorrt_dir, "", "Specify path for loading libnvinfer.so."); + static inline std::string join(const std::string& part1, const std::string& part2) { // directory separator @@ -157,3 +159,12 @@ void GetLapackDsoHandle(void** dso_handle) { GetDsoHandleFromSearchPath(FLAGS_lapack_dir, "liblapacke.so", dso_handle); #endif } + +void GetTensorRtDsoHandle(void** dso_handle) { +#if defined(__APPLE__) || defined(__OSX__) + GetDsoHandleFromSearchPath( + FLAGS_tensorrt_dir, "libnvinfer.dylib", dso_handle); +#else + GetDsoHandleFromSearchPath(FLAGS_tensorrt_dir, "libnvinfer.so", dso_handle); +#endif +} diff --git a/paddle/utils/DynamicLoader.h b/paddle/utils/DynamicLoader.h index 2e5ff76a06152b6a12818f06baaeaa6a69726ba8..02f519de4b3988fb6aca323aaa1751ee2c4bd738 100644 --- a/paddle/utils/DynamicLoader.h +++ b/paddle/utils/DynamicLoader.h @@ -58,3 +58,11 @@ void GetWarpCTCDsoHandle(void** dso_handle); * */ void GetLapackDsoHandle(void** dso_handle); + +/** + * @brief load the DSO of tensorrt + * + * @param **dso_handle dso handler + * + */ +void GetTensorRtDsoHandle(void** dso_handle); diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 5c2c2dd7abebf8960d68b4c4dfd746a4e27acd03..bba8b64bd88c3edc6eda110dde38c0ced50439f6 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -77,6 +77,7 @@ __all__ = [ 'lod_reset', 'lrn', 'pad', + 'label_smooth', ] @@ -3678,3 +3679,68 @@ def pad(x, paddings, pad_value=0., name=None): attrs={'paddings': paddings, 'pad_value': float(pad_value)}) return out + + +def label_smooth(label, + prior_dist=None, + epsilon=0.1, + dtype="float32", + name=None): + """ + Label smoothing is a mechanism to regularize the classifier layer and is + called label-smoothing regularization (LSR). + + Label smoothing is proposed to encourage the model to be less confident, + since optimizing the log-likelihood of the correct label directly may + cause overfitting and reduce the ability of the model to adapt. Label + smoothing replaces the ground-truth label :math:`y` with the weighted sum + of itself and some fixed distribution :math:`\mu`. For class :math:`k`, + i.e. + + .. math:: + + \\tilde{y_k} = (1 - \epsilon) * y_k + \epsilon * \mu_k, + + where :math:`1 - \epsilon` and :math:`\epsilon` are the weights + respectively, and :math:`\\tilde{y}_k` is the smoothed label. Usually + uniform distribution is used for :math:`\mu`. + + See more details about label smoothing in https://arxiv.org/abs/1512.00567. + + Args: + label(Variable): The input variable containing the label data. The + label data should use one-hot representation. + prior_dist(Variable): The prior distribution to be used to smooth + labels. If not provided, an uniform distribution + is used. The shape of :attr:`prior_dist` should + be :math:`(1, class\_num)`. + epsilon(float): The weight used to mix up the original ground-truth + distribution and the fixed distribution. + dtype(np.dtype|core.VarDesc.VarType|str): The type of data : float32, + float_64, int etc. + name(str|None): A name for this layer(optional). If set None, the layer + will be named automatically. + + Returns: + Variable: The tensor variable containing the smoothed labels. + + Examples: + .. code-block:: python + + label = layers.data(name="label", shape=[1], dtype="float32") + one_hot_label = layers.one_hot(input=label, depth=10) + smooth_label = layers.label_smooth( + label=one_hot_label, epsilon=0.1, dtype="float32") + """ + if epsilon > 1. or epsilon < 0.: + raise ValueError("The value of epsilon must be between 0 and 1.") + helper = LayerHelper("label_smooth", **locals()) + label.stop_gradient = True + smooth_label = helper.create_tmp_variable(dtype) + helper.append_op( + type="label_smooth", + inputs={"X": label, + "PriorDist": prior_dist} if prior_dist else {"X": label}, + outputs={"Out": smooth_label}, + attrs={"epsilon": float(epsilon)}) + return smooth_label diff --git a/python/paddle/fluid/metrics.py b/python/paddle/fluid/metrics.py index 99a81c1d4244b919a53dfec36fc5a6659c10adae..c618b02a768f2ca3e2b2914d8ee0134836d5c0d2 100644 --- a/python/paddle/fluid/metrics.py +++ b/python/paddle/fluid/metrics.py @@ -169,7 +169,7 @@ class Accuracy(MetricBase): return self.value / self.weight -class ChunkEvalutor(MetricBase): +class ChunkEvaluator(MetricBase): """ Accumulate counter numbers output by chunk_eval from mini-batches and compute the precision recall and F1-score using the accumulated counter @@ -177,7 +177,7 @@ class ChunkEvalutor(MetricBase): """ def __init__(self, name=None): - super(ChunkEvalutor, self).__init__(name) + super(ChunkEvaluator, self).__init__(name) self.num_infer_chunks = 0 self.num_label_chunks = 0 self.num_correct_chunks = 0 diff --git a/python/paddle/fluid/tests/unittests/test_layers.py b/python/paddle/fluid/tests/unittests/test_layers.py index f88a6f1ce6e953c54da29f9e96199169b2cecd8b..a1be2d671ddc5c689b16319fcf5bf12dca5dde7e 100644 --- a/python/paddle/fluid/tests/unittests/test_layers.py +++ b/python/paddle/fluid/tests/unittests/test_layers.py @@ -340,6 +340,16 @@ class TestBook(unittest.TestCase): print(layers.lod_reset(x=x, y=y)) print(str(program)) + def test_label_smooth(self): + program = Program() + with program_guard(program): + label = layers.data(name="label", shape=[1], dtype="float32") + one_hot_label = layers.one_hot(input=label, depth=10) + smooth_label = layers.label_smooth( + label=one_hot_label, epsilon=0.1, dtype="float32") + self.assertIsNotNone(smooth_label) + print(str(program)) + if __name__ == '__main__': unittest.main()