diff --git a/paddle/fluid/inference/analysis/CMakeLists.txt b/paddle/fluid/inference/analysis/CMakeLists.txt index 0c73778b201d77a6e8a35a38d17f2a86d5faaca9..4bd3f93ef75ada545751fef5af77a78e4872b690 100644 --- a/paddle/fluid/inference/analysis/CMakeLists.txt +++ b/paddle/fluid/inference/analysis/CMakeLists.txt @@ -35,4 +35,4 @@ function(inference_analysis_test TARGET) endif() endfunction(inference_analysis_test) -inference_analysis_test(test_analyzer SRCS analyzer_tester.cc EXTRA_DEPS paddle_inference_api) +inference_analysis_test(test_analyzer SRCS analyzer_tester.cc EXTRA_DEPS reset_tensor_array paddle_inference_api) diff --git a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt index 27fb41d16ead65a1ec075399bcda135e2238c7ba..840abd26a755c39bc9c17315aefdd0dec862e77c 100644 --- a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt +++ b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt @@ -18,7 +18,7 @@ nv_test(test_trt_activation_op SRCS test_activation_op.cc activation_op.cc nv_test(test_trt_conv_op SRCS test_conv2d_op.cc conv2d_op.cc DEPS ${FLUID_CORE_MODULES} ${GLOB_OPERATOR_DEPS} tensorrt_engine conv_op conv_transpose_op SERIAL) nv_test(test_trt_pool2d_op SRCS test_pool2d_op.cc pool2d_op.cc - DEPS ${FLUID_CORE_MODULES} ${GLOB_OPERATOR_DEPS} tensorrt_engine pool_op SERIAL) + DEPS ${FLUID_CORE_MODULES} ${GLOB_OPERATOR_DEPS} tensorrt_engine pool_op tensorrt_plugin SERIAL) nv_test(test_trt_elementwise_op SRCS test_elementwise_op.cc elementwise_op.cc DEPS ${FLUID_CORE_MODULES} ${GLOB_OPERATOR_DEPS} tensorrt_engine tensorrt_plugin elementwise_add_op elementwise_mul_op SERIAL) diff --git a/paddle/fluid/inference/tensorrt/convert/pool2d_op.cc b/paddle/fluid/inference/tensorrt/convert/pool2d_op.cc index 48850020840a49bd309c007943f14b2f7eec5e2d..d700e08590ec5f9a397c3a6de80e0394c0dd4dc5 100644 --- a/paddle/fluid/inference/tensorrt/convert/pool2d_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/pool2d_op.cc @@ -13,25 +13,57 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/inference/tensorrt/convert/op_converter.h" +#include "paddle/fluid/inference/tensorrt/plugin/avg_pool_op_plugin.h" namespace paddle { namespace inference { namespace tensorrt { +void DealCeilMode(const nvinfer1::Dims &input_shape, std::vector ksize, + std::vector strides, std::vector paddings, + nvinfer1::DimsHW *pre_pad, nvinfer1::DimsHW *post_pad, + int input_dims) { + int input_height = input_shape.d[input_dims - 2]; + int input_width = input_shape.d[input_dims - 1]; + int floor_h_output_size = + (input_height - ksize[0] + 2 * paddings[0]) / strides[0] + 1; + int ceil_h_output_size = + (input_height - ksize[0] + 2 * paddings[0] + strides[0] - 1) / + strides[0] + + 1; + + int floor_w_output_size = + (input_width - ksize[1] + 2 * paddings[1]) / strides[1] + 1; + int ceil_w_output_size = + (input_width - ksize[1] + 2 * paddings[1] + strides[1] - 1) / strides[1] + + 1; + if (floor_h_output_size != ceil_h_output_size) { + post_pad->h() = strides[0] - 1; + } + + if (floor_w_output_size != ceil_w_output_size) { + post_pad->w() = strides[1] - 1; + } +} + /* * Pool2dOp, IPoolingLayer in TRT. This Layer doesn't has weights. */ class Pool2dOpConverter : public OpConverter { public: - void operator()(const framework::proto::OpDesc& op, - const framework::Scope& scope, bool test_mode) override { - VLOG(3) + void operator()(const framework::proto::OpDesc &op, + const framework::Scope &scope, bool test_mode) override { + VLOG(40) << "convert a fluid pool2d op to tensorrt pool2d layer without bias"; framework::OpDesc op_desc(op, nullptr); // Declare inputs PADDLE_ENFORCE_EQ(op_desc.Input("X").size(), 1); PADDLE_ENFORCE_EQ(op_desc.Output("Out").size(), 1); - auto* input1 = engine_->GetITensor(op_desc.Input("X")[0]); + auto *input1 = engine_->GetITensor(op_desc.Input("X")[0]); + nvinfer1::Dims input_shape = input1->getDimensions(); + int input_dims = input_shape.nbDims; + + PADDLE_ENFORCE_EQ(input_dims, 3UL); bool global_pooling = boost::get(op_desc.GetAttr("global_pooling")); std::string pool_type = @@ -44,23 +76,6 @@ class Pool2dOpConverter : public OpConverter { boost::get>(op_desc.GetAttr("paddings")); bool ceil_mode = boost::get(op_desc.GetAttr("ceil_mode")); - nvinfer1::Dims input_shape = input1->getDimensions(); - int nbDims = input_shape.nbDims; - nvinfer1::DimsHW nv_ksize(ksize[0], ksize[1]); - nvinfer1::DimsHW nv_strides(strides[0], strides[1]); - nvinfer1::DimsHW nv_paddings(paddings[0], paddings[1]); - - if (global_pooling == true) { - nv_ksize.d[0] = input_shape.d[nbDims - 2]; - nv_ksize.d[1] = input_shape.d[nbDims - 1]; - nv_strides.h() = 1; - nv_strides.w() = 1; - nv_paddings.h() = 0; - nv_paddings.w() = 0; - } - - PADDLE_ENFORCE_EQ(input1->getDimensions().nbDims, 3UL); - nvinfer1::PoolingType nv_pool_type = nvinfer1::PoolingType::kMAX; if (pool_type == "max") { nv_pool_type = nvinfer1::PoolingType::kMAX; @@ -70,42 +85,63 @@ class Pool2dOpConverter : public OpConverter { PADDLE_THROW("TensorRT unsupported pooling type!"); } - if (ceil_mode) { - nvinfer1::DimsHW pre_pad(0, 0); - nvinfer1::DimsHW post_pad(0, 0); - int input_height = input_shape.d[nbDims - 2]; - int input_width = input_shape.d[nbDims - 1]; - int floor_h_output_size = - (input_height - ksize[0] + 2 * paddings[0]) / strides[0] + 1; - int ceil_h_output_size = - (input_height - ksize[0] + 2 * paddings[0] + strides[0] - 1) / - strides[0] + - 1; - - int floor_w_output_size = - (input_width - ksize[1] + 2 * paddings[1]) / strides[1] + 1; - int ceil_w_output_size = - (input_width - ksize[1] + 2 * paddings[1] + strides[1] - 1) / - strides[1] + - 1; - if (floor_h_output_size != ceil_h_output_size) { - post_pad.h() = strides[0] - 1; + nvinfer1::DimsHW nv_ksize(ksize[0], ksize[1]); + nvinfer1::DimsHW nv_strides(strides[0], strides[1]); + nvinfer1::DimsHW nv_paddings(paddings[0], paddings[1]); + + nvinfer1::ILayer *layer = nullptr; + + if (global_pooling == true) { + nv_ksize.d[0] = input_shape.d[input_dims - 2]; + nv_ksize.d[1] = input_shape.d[input_dims - 1]; + auto *layer = TRT_ENGINE_ADD_LAYER( + engine_, Pooling, *const_cast(input1), + nv_pool_type, nv_ksize); + PADDLE_ENFORCE_NOT_NULL(layer, "pool layer could not be created."); + auto output_name = op_desc.Output("Out")[0]; + layer->setName(("pool2d (Output: " + output_name + ")").c_str()); + layer->getOutput(0)->setName(output_name.c_str()); + engine_->SetITensor(output_name, layer->getOutput(0)); + if (test_mode) { + engine_->DeclareOutput(output_name); } + return; + } - if (floor_w_output_size != ceil_w_output_size) { - post_pad.w() = strides[1] - 1; + if (pool_type == "max") { + nvinfer1::DimsHW pre_pad(paddings[0], paddings[1]); + nvinfer1::DimsHW post_pad(paddings[0], paddings[1]); + if (ceil_mode) { + // If ceil mode is true, we will pad the appropriate size to the input. + DealCeilMode(input_shape, ksize, strides, paddings, &pre_pad, &post_pad, + input_dims); + auto *pad_layer = TRT_ENGINE_ADD_LAYER( + engine_, Padding, *const_cast(input1), pre_pad, + post_pad); + PADDLE_ENFORCE_NOT_NULL( + pad_layer, "pad layer in poolOp converter could not be created."); + input1 = pad_layer->getOutput(0); + } + auto *pool_layer = TRT_ENGINE_ADD_LAYER( + engine_, Pooling, *const_cast(input1), + nv_pool_type, nv_ksize); + PADDLE_ENFORCE_NOT_NULL(pool_layer, "pool layer could not be created."); + pool_layer->setStride(nv_strides); + pool_layer->setPadding(nv_paddings); + layer = pool_layer; + } else { + // Average pooling needs to exclude the padding pixels from the average + // mean. + // It is not supported well by TRT, we use a plugin here. + std::vector input_shape_v; + for (int i = 0; i < input_dims; i++) { + input_shape_v.push_back(input_shape.d[i]); } - auto* layer = TRT_ENGINE_ADD_LAYER( - engine_, Padding, *const_cast(input1), pre_pad, - post_pad); - input1 = layer->getOutput(0); + plugin::AvgPoolPlugin *plugin = new plugin::AvgPoolPlugin( + ceil_mode, ksize, strides, paddings, input_shape_v); + auto *avg_pool_layer = engine_->AddPlugin(&input1, 1, plugin); + layer = avg_pool_layer; } - auto* layer = TRT_ENGINE_ADD_LAYER(engine_, Pooling, - *const_cast(input1), - nv_pool_type, nv_ksize); - PADDLE_ENFORCE_NOT_NULL(layer, "pool layer could not be created."); - layer->setStride(nv_strides); - layer->setPadding(nv_paddings); auto output_name = op_desc.Output("Out")[0]; layer->setName(("pool2d (Output: " + output_name + ")").c_str()); diff --git a/paddle/fluid/inference/tensorrt/convert/test_pool2d_op.cc b/paddle/fluid/inference/tensorrt/convert/test_pool2d_op.cc index ee597f8465c218c0fb6648374c128cabf7b033fb..bded833505cd25352adc4123de415613d1fc926d 100644 --- a/paddle/fluid/inference/tensorrt/convert/test_pool2d_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/test_pool2d_op.cc @@ -20,20 +20,21 @@ namespace paddle { namespace inference { namespace tensorrt { -void test_pool2d(bool global_pooling, bool ceil_mode) { +void test_pool2d(bool global_pooling, bool ceil_mode, + std::string pool_type = "max") { framework::Scope scope; std::unordered_set parameters; TRTConvertValidation validator(5, parameters, scope, 1 << 15); // The ITensor's Dims should not contain the batch size. // So, the ITensor's Dims of input and output should be C * H * W. - validator.DeclInputVar("pool2d-X", nvinfer1::Dims3(3, 13, 14)); + validator.DeclInputVar("pool2d-X", nvinfer1::Dims3(3, 6, 7)); if (global_pooling) validator.DeclOutputVar("pool2d-Out", nvinfer1::Dims3(3, 1, 1)); else if (ceil_mode) - validator.DeclOutputVar("pool2d-Out", nvinfer1::Dims3(3, 6, 7)); + validator.DeclOutputVar("pool2d-Out", nvinfer1::Dims3(3, 3, 4)); else - validator.DeclOutputVar("pool2d-Out", nvinfer1::Dims3(3, 6, 6)); + validator.DeclOutputVar("pool2d-Out", nvinfer1::Dims3(3, 3, 3)); // Prepare Op description framework::OpDesc desc; @@ -41,10 +42,10 @@ void test_pool2d(bool global_pooling, bool ceil_mode) { desc.SetInput("X", {"pool2d-X"}); desc.SetOutput("Out", {"pool2d-Out"}); - std::vector ksize({3, 3}); + std::vector ksize({2, 2}); std::vector strides({2, 2}); std::vector paddings({0, 0}); - std::string pooling_t = "max"; + std::string pooling_t = pool_type; desc.SetAttr("pooling_type", pooling_t); desc.SetAttr("ksize", ksize); @@ -63,7 +64,8 @@ void test_pool2d(bool global_pooling, bool ceil_mode) { TEST(Pool2dOpConverter, normal) { test_pool2d(false, false); } TEST(Pool2dOpConverter, test_global_pooling) { test_pool2d(true, false); } -TEST(Pool2dOpConverter, test_ceil_mode) { test_pool2d(false, true); } +TEST(Pool2dOpConverter, max_ceil_test) { test_pool2d(false, true); } +TEST(Pool2dOpConverter, avg_ceil_test) { test_pool2d(false, true, "avg"); } } // namespace tensorrt } // namespace inference diff --git a/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt b/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt index a0329325bea19bd9cdd3fcd39724cf05664b505a..e822785ad6f4f6f67b72141f3e7b04aefa72e58b 100644 --- a/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt +++ b/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt @@ -1,3 +1,4 @@ nv_library(tensorrt_plugin SRCS trt_plugin.cc split_op_plugin.cu elementwise_op_plugin.cu prelu_op_plugin.cu + avg_pool_op_plugin.cu DEPS enforce tensorrt_engine) diff --git a/paddle/fluid/inference/tensorrt/plugin/avg_pool_op_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/avg_pool_op_plugin.cu new file mode 100644 index 0000000000000000000000000000000000000000..5d747af8c55d71fee90ee0cc06fd328e583f3700 --- /dev/null +++ b/paddle/fluid/inference/tensorrt/plugin/avg_pool_op_plugin.cu @@ -0,0 +1,64 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/inference/tensorrt/plugin/avg_pool_op_plugin.h" +#include "paddle/fluid/operators/math/pooling.h" + +namespace paddle { +namespace inference { +namespace tensorrt { +namespace plugin { + +nvinfer1::Dims AvgPoolPlugin::getOutputDimensions( + int index, const nvinfer1::Dims* inputDims, int nbInputs) { + assert(nbInputs == 1); + assert(index == 0); + assert(inputDims[0].nbDims == 3); + nvinfer1::Dims const& input_dims = inputDims[0]; + + nvinfer1::Dims output_dims = input_dims; + + output_dims.d[1] = output_shape_[1]; + output_dims.d[2] = output_shape_[2]; + return output_dims; +} + +int AvgPoolPlugin::enqueue(int batchSize, const void* const* inputs, + void** outputs, void* workspace, + cudaStream_t stream) { + auto const& input_dims = this->getInputDims(0); + int input_size = 0; + float const* idata = reinterpret_cast(inputs[0]); + float** odatas = reinterpret_cast(outputs); + + paddle::operators::math::AvgPool pool_process; + paddle::operators::math::Pool2dDirectCUDAFunctor< + paddle::operators::math::AvgPool, float> + pool2d_forward; + + std::vector input_shape = input_shape_; + std::vector output_shape = output_shape_; + input_shape.insert(input_shape.begin(), batchSize); + output_shape.insert(output_shape.begin(), batchSize); + + pool2d_forward(idata, input_shape, output_shape, ksize_, strides_, paddings_, + pool_process, true, odatas[0], stream); + + return cudaGetLastError() != cudaSuccess; +} + +} // namespace plugin +} // namespace tensorrt +} // namespace inference +} // namespace paddle diff --git a/paddle/fluid/inference/tensorrt/plugin/avg_pool_op_plugin.h b/paddle/fluid/inference/tensorrt/plugin/avg_pool_op_plugin.h new file mode 100644 index 0000000000000000000000000000000000000000..b5e4ece0fba446627d619df6fe225e8c07231487 --- /dev/null +++ b/paddle/fluid/inference/tensorrt/plugin/avg_pool_op_plugin.h @@ -0,0 +1,111 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include +#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h" + +namespace paddle { +namespace inference { +namespace tensorrt { +namespace plugin { + +class AvgPoolPlugin : public PluginTensorRT { + private: + bool ceil_mode_; + std::vector ksize_; + std::vector strides_; + std::vector paddings_; + std::vector input_shape_; + std::vector output_shape_; + + protected: + size_t getSerializationSize() override { + return SerializedSize(ceil_mode_) + SerializedSize(ksize_) + + SerializedSize(strides_) + SerializedSize(paddings_) + + SerializedSize(input_shape_) + getBaseSerializationSize(); + } + + // TRT will call this func when we need to serialize the configuration of + // tensorrt. + // It should not be called by users. + void serialize(void *buffer) override { + serializeBase(buffer); + SerializeValue(&buffer, ceil_mode_); + SerializeValue(&buffer, ksize_); + SerializeValue(&buffer, strides_); + SerializeValue(&buffer, paddings_); + SerializeValue(&buffer, input_shape_); + } + + public: + AvgPoolPlugin(bool ceil_mode, std::vector ksize, + std::vector strides, std::vector paddings, + std::vector input_shape) + : ceil_mode_(ceil_mode), + ksize_(ksize), + strides_(strides), + paddings_(paddings), + input_shape_(input_shape) { + int output_h, output_w; + output_shape_ = input_shape_; + if (!ceil_mode_) { + output_h = + (input_shape[1] - ksize_[0] + 2 * paddings_[0]) / strides_[0] + 1; + output_w = + (input_shape[2] - ksize_[1] + 2 * paddings_[1]) / strides_[1] + 1; + } else { + output_h = + (input_shape[1] - ksize_[0] + 2 * paddings_[0] + strides_[0] - 1) / + strides_[0] + + 1; + output_w = + (input_shape[2] - ksize_[1] + 2 * paddings_[1] + strides_[1] - 1) / + strides_[1] + + 1; + } + output_shape_[1] = output_h; + output_shape_[2] = output_w; + } + + // It was used for tensorrt deserialization. + // It should not be called by users. + AvgPoolPlugin(void const *serialData, size_t serialLength) { + deserializeBase(serialData, serialLength); + DeserializeValue(&serialData, &serialLength, &ceil_mode_); + DeserializeValue(&serialData, &serialLength, &ksize_); + DeserializeValue(&serialData, &serialLength, &strides_); + DeserializeValue(&serialData, &serialLength, &paddings_); + DeserializeValue(&serialData, &serialLength, &input_shape_); + } + + AvgPoolPlugin *clone() const override { + return new AvgPoolPlugin(ceil_mode_, ksize_, strides_, paddings_, + input_shape_); + } + + const char *getPluginType() const override { return "avg_pool"; } + int getNbOutputs() const override { return 1; } + nvinfer1::Dims getOutputDimensions(int index, const nvinfer1::Dims *inputs, + int nbInputDims) override; + int initialize() override { return 0; } + int enqueue(int batchSize, const void *const *inputs, void **outputs, + void *workspace, cudaStream_t stream) override; +}; + +} // namespace plugin +} // namespace tensorrt +} // namespace inference +} // namespace paddle diff --git a/paddle/fluid/memory/allocation/best_fit_allocator_test.cc b/paddle/fluid/memory/allocation/best_fit_allocator_test.cc index 4122b3d709e095c08b4fb2667103649a03eee64f..20748a23a1951383c888d9b8d7a360ec941e50cb 100644 --- a/paddle/fluid/memory/allocation/best_fit_allocator_test.cc +++ b/paddle/fluid/memory/allocation/best_fit_allocator_test.cc @@ -13,6 +13,7 @@ // limitations under the License. #include "paddle/fluid/memory/allocation/best_fit_allocator.h" +#include #include // NOLINT #include #include "gtest/gtest.h" diff --git a/paddle/fluid/memory/allocation/best_fit_allocator_test.cu b/paddle/fluid/memory/allocation/best_fit_allocator_test.cu index 50aecda97a9abb64f81c6e0e1d268e57a3aad3f0..f7f17e1d36e0adef0b0eb7a43715836db4b7927d 100644 --- a/paddle/fluid/memory/allocation/best_fit_allocator_test.cu +++ b/paddle/fluid/memory/allocation/best_fit_allocator_test.cu @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include #include // NOLINT #include #include "gtest/gtest.h" diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt index 975c3bfc3362413b9af0edf1a3e5b4b64635132d..81c9239486ad00fe702af32c2171906baeeeae28 100644 --- a/paddle/fluid/operators/CMakeLists.txt +++ b/paddle/fluid/operators/CMakeLists.txt @@ -34,29 +34,37 @@ if (WITH_GPU AND TENSORRT_FOUND) add_subdirectory(tensorrt) endif() -register_operators(EXCLUDES warpctc_op conv_fusion_op) +SET(OP_HEADER_DEPS xxhash) +if (WITH_GPU) + SET(OP_HEADER_DEPS ${OP_HEADER_DEPS} cub) +endif() -# warpctc_cudnn need cudnn 7 above +register_operators(EXCLUDES warpctc_op conv_fusion_op DEPS ${OP_HEADER_DEPS}) + +# warpctc_op needs cudnn 7 above if (WITH_GPU) if (${CUDNN_MAJOR_VERSION} VERSION_LESS 7) op_library(warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale SRCS warpctc_op.cc warpctc_op.cu.cc) else() op_library(warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale) endif() - op_library(conv_fusion_op) - file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(conv2d_fusion);\n") + # conv_fusion_op needs cudnn 7 above + if (NOT ${CUDNN_MAJOR_VERSION} VERSION_LESS 7) + op_library(conv_fusion_op) + file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(conv2d_fusion);\n") + endif() else() op_library(warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale) endif() -set(COMMON_OP_DEPS "") +set(COMMON_OP_DEPS ${OP_HEADER_DEPS}) -set(COMMON_OP_DEPS ${COMMON_OP_DEPS} xxhash selected_rows_functor selected_rows lod_tensor maxouting unpooling pooling lod_rank_table context_project sequence_pooling executor dynload_warpctc sequence_padding sequence_scale cos_sim_functor memory jit_kernel concat_and_split cross_entropy softmax vol2col im2col sampler) +set(COMMON_OP_DEPS ${COMMON_OP_DEPS} selected_rows_functor selected_rows lod_tensor maxouting unpooling pooling lod_rank_table context_project sequence_pooling executor dynload_warpctc sequence_padding sequence_scale cos_sim_functor memory jit_kernel concat_and_split cross_entropy softmax vol2col im2col sampler) if (NOT WIN32) set(COMMON_OP_DEPS ${COMMON_OP_DEPS} sequence2batch lstm_compute matrix_bit_code gru_compute activation_functions) endif() if (WITH_GPU) - set(COMMON_OP_DEPS ${COMMON_OP_DEPS} depthwise_conv cub) + set(COMMON_OP_DEPS ${COMMON_OP_DEPS} depthwise_conv) endif() # FIXME(typhoonzero): operator deps may not needed. diff --git a/paddle/fluid/operators/conv_fusion_op.cu.cc b/paddle/fluid/operators/conv_fusion_op.cu.cc index bd1041ce0836014dc73fabd4a3896243a943bd38..2c09ee7394ad605f7a324d021ce0468a79bb71ca 100644 --- a/paddle/fluid/operators/conv_fusion_op.cu.cc +++ b/paddle/fluid/operators/conv_fusion_op.cu.cc @@ -22,6 +22,7 @@ DECLARE_bool(cudnn_exhaustive_search); namespace paddle { namespace operators { +#if CUDNN_VERSION >= 7001 using Tensor = framework::Tensor; using ScopedTensorDescriptor = platform::ScopedTensorDescriptor; using ScopedFilterDescriptor = platform::ScopedFilterDescriptor; @@ -178,10 +179,13 @@ class CUDNNConvFusionOpKernel : public framework::OpKernel { workspace_handle.RunFunc(cudnn_func, workspace_size_in_bytes); } }; +#endif } // namespace operators } // namespace paddle +#if CUDNN_VERSION >= 7001 namespace ops = paddle::operators; REGISTER_OP_CUDA_KERNEL(conv2d_fusion, ops::CUDNNConvFusionOpKernel, ops::CUDNNConvFusionOpKernel); +#endif diff --git a/paddle/fluid/operators/math/pooling.cu b/paddle/fluid/operators/math/pooling.cu index a689eb42242e551caa3470f34f7e8d7e80b6dfbe..cdc79e207aa9a2e59e25a07002134c12ad5a1df8 100644 --- a/paddle/fluid/operators/math/pooling.cu +++ b/paddle/fluid/operators/math/pooling.cu @@ -153,6 +153,37 @@ __global__ void KernelMaxPool2DGrad( } } +template +void Pool2dDirectCUDAFunctor::operator()( + const T* input, const std::vector& input_shape, + const std::vector& output_shape, const std::vector& ksize, + const std::vector& strides, const std::vector& paddings, + PoolProcess pool_compute, bool exclusive, T* output, cudaStream_t stream) { + const int batch_size = input_shape[0]; + const int input_channels = input_shape[1]; + const int input_height = input_shape[2]; + const int input_width = input_shape[3]; + const int output_channels = output_shape[1]; + const int output_height = output_shape[2]; + const int output_width = output_shape[3]; + const int ksize_height = ksize[0]; + const int ksize_width = ksize[1]; + const int stride_height = strides[0]; + const int stride_width = strides[1]; + const int padding_height = paddings[0]; + const int padding_width = paddings[1]; + + int nthreads = batch_size * output_channels * output_height * output_width; + int blocks = (nthreads + 1024 - 1) / 1024; + dim3 threads(1024, 1); + dim3 grid(blocks, 1); + + KernelPool2D<<>>( + nthreads, input, input_channels, input_height, input_width, output_height, + output_width, ksize_height, ksize_width, stride_height, stride_width, + padding_height, padding_width, pool_compute, exclusive, output); +} + /* * All tensors are in NCHW format. * Ksize, strides, paddings are two elements. These two elements represent @@ -291,6 +322,11 @@ class MaxPool2dGradFunctor { } }; +template class Pool2dDirectCUDAFunctor, + float>; +template class Pool2dDirectCUDAFunctor, + float>; + template class MaxPool2dGradFunctor; template class MaxPool2dGradFunctor; diff --git a/paddle/fluid/operators/math/pooling.h b/paddle/fluid/operators/math/pooling.h index 0f64e321bf01eea69767af020ed8c1a75e31acb5..923babd4c248364b735bb09def7bf12f2762f305 100644 --- a/paddle/fluid/operators/math/pooling.h +++ b/paddle/fluid/operators/math/pooling.h @@ -82,6 +82,19 @@ class AvgPoolGrad { * This is different from average pooling. So we rewrite the max_pool_grad: * MaxPool2dGradFunctor, MaxPool3dGradFunctor. */ +#ifdef PADDLE_WITH_CUDA +template +class Pool2dDirectCUDAFunctor { + public: + void operator()(const T* input, const std::vector& input_shape, + const std::vector& output_shape, + const std::vector& ksize, + const std::vector& strides, + const std::vector& paddings, PoolProcess pool_compute, + bool exclusive, T* output, cudaStream_t stream); +}; +#endif + template class Pool2dFunctor { public: diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 99acd7e30884b46cb14e27ac4569af82af311a3a..7b0a3e2c82b55f7fc646f970c2df6f66b696a865 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -726,11 +726,11 @@ def dynamic_gru(input, create ParamAttr as param_attr. If the Initializer of the param_attr is not set, the parameter is initialized with Xavier. Default: None. bias_attr (ParamAttr|bool|None): The parameter attribute for the bias - of GRU. Note that the bias with :math:`(1 \\times 3D)` concatenates + of GRU. Note that the bias with :math:`(1 \\times 3D)` concatenates the bias in the update gate, reset gate and candidate calculations. - If it is set to False, no bias will be applied to the update gate, - reset gate and candidate calculations. If it is set to None or one - attribute of ParamAttr, dynamic_gru will create ParamAttr as + If it is set to False, no bias will be applied to the update gate, + reset gate and candidate calculations. If it is set to None or one + attribute of ParamAttr, dynamic_gru will create ParamAttr as bias_attr. If the Initializer of the bias_attr is not set, the bias is initialized zero. Default: None. is_reverse(bool): Whether to compute reversed GRU, default @@ -847,11 +847,11 @@ def gru_unit(input, create ParamAttr as param_attr. If the Initializer of the param_attr is not set, the parameter is initialized with Xavier. Default: None. bias_attr (ParamAttr|bool|None): The parameter attribute for the bias - of GRU. Note that the bias with :math:`(1 \\times 3D)` concatenates + of GRU. Note that the bias with :math:`(1 \\times 3D)` concatenates the bias in the update gate, reset gate and candidate calculations. - If it is set to False, no bias will be applied to the update gate, - reset gate and candidate calculations. If it is set to None or one - attribute of ParamAttr, gru_unit will create ParamAttr as + If it is set to False, no bias will be applied to the update gate, + reset gate and candidate calculations. If it is set to None or one + attribute of ParamAttr, gru_unit will create ParamAttr as bias_attr. If the Initializer of the bias_attr is not set, the bias is initialized zero. Default: None. activation (string): The activation type for cell (actNode). @@ -1064,9 +1064,9 @@ def dropout(x, inference: out = input (make is a tensor same shape with input, value is 0 or 1 ratio of 0 is dropout_prob) - dropout op can be removed from the program. + dropout op can be removed from the program. the program will be efficient - + Returns: @@ -2149,7 +2149,7 @@ def pool2d(input, ceil_mode (bool): ${ceil_mode_comment} name (str|None): A name for this layer(optional). If set None, the layer will be named automatically. - exclusive (bool): Whether to exclude padding points in average pooling + exclusive (bool): Whether to exclude padding points in average pooling mode, default is true Returns: @@ -2240,7 +2240,7 @@ def pool3d(input, ceil_mode (bool): ${ceil_mode_comment} name (str): A name for this layer(optional). If set None, the layer will be named automatically. - exclusive (bool): Whether to exclude padding points in average pooling + exclusive (bool): Whether to exclude padding points in average pooling mode, default is true Returns: @@ -4342,7 +4342,7 @@ def nce(input, sampler (str): The sampler used to sample class from negtive classes. It can be 'uniform', 'log_uniform' or 'custom_dist'. default: 'uniform'. - custom_dist (Variable): A tensor with shape [num_total_classes]. + custom_dist (Variable): A tensor with shape [num_total_classes]. It is used when sampler is set to 'custom_dist'. custom_dist[i] is the probsbility of i-th class to be sampled. default: None. @@ -4385,7 +4385,7 @@ def nce(input, num_neg_samples=3, sampler="custom_dist", custom_dist=dist) - + """ helper = LayerHelper('nce', **locals()) assert isinstance(input, Variable) @@ -4556,9 +4556,9 @@ def transpose(x, perm, name=None): Examples: .. code-block:: python - # use append_batch_size=False to avoid prepending extra + # use append_batch_size=False to avoid prepending extra # batch size in shape - x = fluid.layers.data(name='x', shape=[5, 10, 15], + x = fluid.layers.data(name='x', shape=[5, 10, 15], dtype='float32', append_batch_size=False) x_transposed = layers.transpose(x, perm=[1, 0, 2]) """ @@ -4835,7 +4835,7 @@ def softmax_with_cross_entropy(logits, 3) If numeric_stable_mode is True, softmax is calculated first by: .. math:: - + max_j = \\max_{i=0}^{K}{\\text{logit}_i} log\\_max\\_sum_j = \\log\\sum_{i=0}^{K}\\exp(logit_i - max_j) @@ -4858,18 +4858,18 @@ def softmax_with_cross_entropy(logits, numeric_stable_mode (bool): A flag to indicate whether to use a more numerically stable algorithm. Only valid when soft_label is False and GPU is used. - When soft_label is True or CPU is used, - the algorithm is always numerically stable. - Note that the speed may be slower when use + When soft_label is True or CPU is used, + the algorithm is always numerically stable. + Note that the speed may be slower when use stable algorithm. Default: False - return_softmax (bool): A flag indicating whether to return the softmax + return_softmax (bool): A flag indicating whether to return the softmax along with the cross entropy loss. Default: False Returns: - Variable or Tuple of two Variables: Return the cross entropy loss if - `return_softmax` is False, otherwise the tuple - (loss, softmax), where the cross entropy loss is - a 2-D tensor with shape [N x 1], and softmax is a + Variable or Tuple of two Variables: Return the cross entropy loss if + `return_softmax` is False, otherwise the tuple + (loss, softmax), where the cross entropy loss is + a 2-D tensor with shape [N x 1], and softmax is a 2-D tensor with shape [N x K]. Examples: @@ -5756,20 +5756,20 @@ def image_resize(input, Default: None name(str|None): A name for this layer(optional). If set None, the layer will be named automatically. - resample(str): The resample method. It supports 'BILINEAR' and 'NEAREST' + resample(str): The resample method. It supports 'BILINEAR' and 'NEAREST' currently. Default: 'BILINEAR' - actual_shape(Variable): An optional input to specify output shape - dynamically. If provided, image resize - according to this given shape rather than + actual_shape(Variable): An optional input to specify output shape + dynamically. If provided, image resize + according to this given shape rather than :attr:`out_shape` and :attr:`scale` specifying - shape. That is to say actual_shape has the - highest priority. It is recommended to use - actual_shape instead of :attr:`out_shape` if you - want to specify output shape dynamically. When - using actual_shape to specify output shape, one of - :attr:`out_shape` and :attr:`scale` should also be - set, otherwise errors would be occured in graph + shape. That is to say actual_shape has the + highest priority. It is recommended to use + actual_shape instead of :attr:`out_shape` if you + want to specify output shape dynamically. When + using actual_shape to specify output shape, one of + :attr:`out_shape` and :attr:`scale` should also be + set, otherwise errors would be occured in graph constructing stage. Default: None @@ -5780,7 +5780,7 @@ def image_resize(input, Raises: TypeError: out_shape should be a list or tuple or Variable. TypeError: actual_shape should either be Variable or None. - ValueError: The 'resample' of image_resize can only be 'BILINEAR' + ValueError: The 'resample' of image_resize can only be 'BILINEAR' or 'NEAREST' currently. ValueError: One of out_shape and scale must not be None. ValueError: out_shape length should be 2. @@ -5852,17 +5852,17 @@ def resize_bilinear(input, name=None, actual_shape=None): """ - Resize input by performing bilinear interpolation based on given - output shape which specified by actual_shape, out_shape and scale + Resize input by performing bilinear interpolation based on given + output shape which specified by actual_shape, out_shape and scale in priority order. - Bilinear interpolation is an extension of linear interpolation for - interpolating functions of two variables (e.g. H-direction and - W-direction in this op) on a rectilinear 2D grid. The key idea is - to perform linear interpolation first in one direction, and then + Bilinear interpolation is an extension of linear interpolation for + interpolating functions of two variables (e.g. H-direction and + W-direction in this op) on a rectilinear 2D grid. The key idea is + to perform linear interpolation first in one direction, and then again in the other direction. - For details of bilinear interpolation, please refer to Wikipedia: + For details of bilinear interpolation, please refer to Wikipedia: https://en.wikipedia.org/wiki/Bilinear_interpolation Args: @@ -5875,17 +5875,17 @@ def resize_bilinear(input, a higher priority than scale. Default: None. name(str|None): The output variable name. - actual_shape(Variable): An optional input to specify output shape - dynamically. If provided, image resize - according to this given shape rather than + actual_shape(Variable): An optional input to specify output shape + dynamically. If provided, image resize + according to this given shape rather than :attr:`out_shape` and :attr:`scale` specifying - shape. That is to say actual_shape has the - highest priority. It is recommended to use - actual_shape instead of :attr:`out_shape` if you - want to specify output shape dynamically. When - using actual_shape to specify output shape, one of - :attr:`out_shape` and :attr:`scale` should also be - set, otherwise errors would be occured in graph + shape. That is to say actual_shape has the + highest priority. It is recommended to use + actual_shape instead of :attr:`out_shape` if you + want to specify output shape dynamically. When + using actual_shape to specify output shape, one of + :attr:`out_shape` and :attr:`scale` should also be + set, otherwise errors would be occured in graph constructing stage. Default: None @@ -5909,11 +5909,11 @@ def resize_nearest(input, actual_shape=None): """ Resize input by performing nearest neighbor interpolation in both the - 3rd dimention(in height direction) and the 4th dimention(in width - direction) based on given output shape which specified by actual_shape, + 3rd dimention(in height direction) and the 4th dimention(in width + direction) based on given output shape which specified by actual_shape, out_shape and scale in priority order. - For details of nearest neighbor interpolation, please refer to Wikipedia: + For details of nearest neighbor interpolation, please refer to Wikipedia: https://en.wikipedia.org/wiki/Nearest-neighbor_interpolation Args: @@ -5926,17 +5926,17 @@ def resize_nearest(input, a higher priority than scale. Default: None. name(str|None): The output variable name. - actual_shape(Variable): An optional input to specify output shape - dynamically. If provided, image resize - according to this given shape rather than + actual_shape(Variable): An optional input to specify output shape + dynamically. If provided, image resize + according to this given shape rather than :attr:`out_shape` and :attr:`scale` specifying - shape. That is to say actual_shape has the - highest priority. It is recommended to use - actual_shape instead of :attr:`out_shape` if you - want to specify output shape dynamically. When - using actual_shape to specify output shape, one of - :attr:`out_shape` and :attr:`scale` should also be - set, otherwise errors would be occured in graph + shape. That is to say actual_shape has the + highest priority. It is recommended to use + actual_shape instead of :attr:`out_shape` if you + want to specify output shape dynamically. When + using actual_shape to specify output shape, one of + :attr:`out_shape` and :attr:`scale` should also be + set, otherwise errors would be occured in graph constructing stage. Default: None @@ -6446,15 +6446,15 @@ def affine_grid(theta, out_shape, name=None): [x_14, x_15, x_16]] [[x_21, x_22, x_23] [x_24, x_25, x_26]]] - + out_shape = [2, 3, 5, 5] - + Step 1: - + Generate normalized coordinates according to out_shape. The values of the normalized coordinates are in the interval between -1 and 1. The shape of the normalized coordinates is [2, H, W] as below: - + C = [[[-1. -1. -1. -1. -1. ] [-0.5 -0.5 -0.5 -0.5 -0.5] [ 0. 0. 0. 0. 0. ] @@ -7702,6 +7702,15 @@ def logical_and(x, y, out=None, name=None): Returns: out(${out_type}): ${out_comment} + + Examples: + .. code-block:: python + + left = fluid.layers.data( + name='left', shape=[1], dtype='int32') + right = fluid.layers.data( + name='right', shape=[1], dtype='int32') + result = fluid.layers.logical_and(x=left, y=right) """ return _logical_op( @@ -7721,6 +7730,15 @@ def logical_or(x, y, out=None, name=None): Returns: out(${out_type}): ${out_comment} + + Examples: + .. code-block:: python + + left = fluid.layers.data( + name='left', shape=[1], dtype='int32') + right = fluid.layers.data( + name='right', shape=[1], dtype='int32') + result = fluid.layers.logical_or(x=left, y=right) """ return _logical_op( @@ -7740,6 +7758,15 @@ def logical_xor(x, y, out=None, name=None): Returns: out(${out_type}): ${out_comment} + + Examples: + .. code-block:: python + + left = fluid.layers.data( + name='left', shape=[1], dtype='int32') + right = fluid.layers.data( + name='right', shape=[1], dtype='int32') + result = fluid.layers.logical_xor(x=left, y=right) """ return _logical_op( @@ -7758,6 +7785,13 @@ def logical_not(x, out=None, name=None): Returns: out(${out_type}): ${out_comment} + + Examples: + .. code-block:: python + + left = fluid.layers.data( + name='left', shape=[1], dtype='int32') + result = fluid.layers.logical_not(x=left) """ return _logical_op( @@ -7777,6 +7811,13 @@ def clip(x, min, max, name=None): Returns: out(${out_type}): ${out_comment} + + Examples: + .. code-block:: python + + input = fluid.layers.data( + name='data', shape=[1], dtype='float32') + reward = fluid.layers.clip(x=input, min=-1.0, max=1.0) """ helper = LayerHelper("clip", **locals()) @@ -7809,6 +7850,13 @@ def clip_by_norm(x, max_norm, name=None): Returns: out(${out_type}): ${out_comment} + + Examples: + .. code-block:: python + + input = fluid.layers.data( + name='data', shape=[1], dtype='float32') + reward = fluid.layers.clip_by_norm(x=input, max_norm=1.0) """ helper = LayerHelper("clip_by_norm", **locals()) @@ -7954,19 +8002,19 @@ def maxout(x, groups, name=None): def space_to_depth(x, blocksize, name=None): """ Gives a blocksize to space_to_depth the input LoDtensor with Layout: [batch, channel, height, width] - - This op rearranges blocks of spatial data, into depth. More specifically, this op outputs a copy of the - input LoDtensor where values from the height and width dimensions are moved to the channel dimension. + + This op rearranges blocks of spatial data, into depth. More specifically, this op outputs a copy of the + input LoDtensor where values from the height and width dimensions are moved to the channel dimension. The attr blocksize indicates the input block size. - - space_to_depth will reorgnize the elements of input with shape[batch, channel, height, width] according + + space_to_depth will reorgnize the elements of input with shape[batch, channel, height, width] according to blocksize to construct output with shape [batch, channel * blocksize * blocksize, height/blocksize, width/blocksize]: - - space_to_depth is used to This operation is useful for resizing the activations between convolutions + + space_to_depth is used to This operation is useful for resizing the activations between convolutions (but keeping all data) - Non-overlapping blocks of size block_size x block size are rearranged into depth at each location. - - The depth of the output tensor is block_size * block_size * input channel + - The depth of the output tensor is block_size * block_size * input channel - The Y, X coordinates within each block of the input become the high order component of the output channel index - channel should be divisible by square of blocksize - height, width should be divsible by blocksize @@ -8013,7 +8061,7 @@ def space_to_depth(x, blocksize, name=None): @templatedoc() def sequence_reverse(x, name=None): - """ + """ ${comment} Args: @@ -8080,21 +8128,21 @@ def affine_channel(x, scale=None, bias=None, data_layout='NCHW', name=None): def similarity_focus(input, axis, indexes, name=None): - """ + """ SimilarityFocus Operator Generate a similarity focus mask with the same shape of input using the following method: - 1. Extract the 3-D tensor(here the first dimension is BatchSize) corresponding - to the axis according to the indexes. For example, if axis=1 and indexes=[a], - it will get the matrix T=X[:, a, :, :]. In this case, if the shape of input X + 1. Extract the 3-D tensor(here the first dimension is BatchSize) corresponding + to the axis according to the indexes. For example, if axis=1 and indexes=[a], + it will get the matrix T=X[:, a, :, :]. In this case, if the shape of input X is (BatchSize, A, B, C), the shape of tensor T is (BatchSize, B, C). - 2. For each index, find the largest numbers in the tensor T, so that the same - row and same column has at most one number(what it means is that if the - largest number has been found in the i-th row and the j-th column, then - the numbers in the i-th row or j-th column will be skipped. And then the - next largest number will be selected from the remaining numbers. Obviously - there will be min(B, C) numbers), and mark the corresponding position of the - 3-D similarity focus mask as 1, otherwise as 0. Do elementwise-or for + 2. For each index, find the largest numbers in the tensor T, so that the same + row and same column has at most one number(what it means is that if the + largest number has been found in the i-th row and the j-th column, then + the numbers in the i-th row or j-th column will be skipped. And then the + next largest number will be selected from the remaining numbers. Obviously + there will be min(B, C) numbers), and mark the corresponding position of the + 3-D similarity focus mask as 1, otherwise as 0. Do elementwise-or for each index. 3. Broadcast the 3-D similarity focus mask to the same shape of input X. @@ -8150,16 +8198,16 @@ def similarity_focus(input, axis, indexes, name=None): [1.0, 0.0]]]] Args: - input(Variable): The input tensor variable(default float). It should + input(Variable): The input tensor variable(default float). It should be a 4-D tensor with shape [BatchSize, A, B, C]. axis(int): Indicating the dimension to be selected. It can only be 1, 2 or 3. indexes(list): Indicating the indexes of the selected dimension. Returns: - Variable: A tensor variable with the same shape and same type + Variable: A tensor variable with the same shape and same type as the input. - + Examples: .. code-block:: python data = fluid.layers.data( @@ -8262,12 +8310,12 @@ def hash(input, hash_size, num_hash=1, name=None): @templatedoc() def grid_sampler(x, grid, name=None): """ - This operation samples input X by using bilinear interpolation based on + This operation samples input X by using bilinear interpolation based on flow field grid, which is usually gennerated by affine_grid. The grid of - shape [N, H, W, 2] is the concatenation of (grid_x, grid_y) coordinates - with shape [N, H, W] each, where grid_x is indexing the 4th dimension - (in width dimension) of input data x and grid_y is indexng the 3rd - dimention (in height dimension), finally results is the bilinear + shape [N, H, W, 2] is the concatenation of (grid_x, grid_y) coordinates + with shape [N, H, W] each, where grid_x is indexing the 4th dimension + (in width dimension) of input data x and grid_y is indexng the 3rd + dimention (in height dimension), finally results is the bilinear interpolation value of 4 nearest corner points. Step 1: @@ -8277,7 +8325,7 @@ def grid_sampler(x, grid, name=None): grid_y = 0.5 * (grid[:, :, :, 1] + 1) * (H - 1) Step 2: - Indices input data X with grid (x, y) in each [H, W] area, and bilinear + Indices input data X with grid (x, y) in each [H, W] area, and bilinear interpolate point value by 4 nearest points. wn ------- y_n ------- en @@ -8314,7 +8362,7 @@ def grid_sampler(x, grid, name=None): name (str, default None): The name of this layer. Returns: - out(Variable): Output of shape [N, C, H, W] data samples input X + out(Variable): Output of shape [N, C, H, W] data samples input X using bilnear interpolation based on input grid. Exmples: diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index 29e4ca04a7fbb2eae870fcf15763310b849c8b53..510d0304f04ff1e7dc4c9bf37169778f85ca6db8 100644 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -23,6 +23,10 @@ if(NOT WITH_DISTRIBUTE) LIST(REMOVE_ITEM TEST_OPS test_dist_text_classification) endif(NOT WITH_DISTRIBUTE) +if (${CUDNN_MAJOR_VERSION} VERSION_LESS 7) + LIST(REMOVE_ITEM TEST_OPS test_conv2d_fusion_op) +endif() + list(REMOVE_ITEM TEST_OPS test_seq_concat_op) # FIXME(helin): https://github.com/PaddlePaddle/Paddle/issues/8290 list(REMOVE_ITEM TEST_OPS test_modified_huber_loss_op) # FIXME(qijun) https://github.com/PaddlePaddle/Paddle/issues/5184 list(REMOVE_ITEM TEST_OPS test_lstm_unit_op) # # FIXME(qijun) https://github.com/PaddlePaddle/Paddle/issues/5185 diff --git a/python/requirements.txt b/python/requirements.txt index 84cf440397b994ba12fa70d9e316e788f34e2415..2f81d85df0626b294f4d861706b5c1b7ec9841d5 100644 --- a/python/requirements.txt +++ b/python/requirements.txt @@ -1,5 +1,5 @@ requests==2.9.2 -numpy>=1.12,<=1.14 #TODO:change to ">=1.12" when numpy fix bug in 1.15 and higher version +numpy>=1.12 protobuf==3.1 recordio>=0.1.0 matplotlib==2.2.3 # TODO: let python3 paddlepaddle package use latest matplotlib