diff --git a/CMakeLists.txt b/CMakeLists.txt index 9b77659f6142da3c8b6bb4913a8219683b723a76..9ad69738eb2ac21d6ff2624f11d17a38410d5c1f 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -75,7 +75,6 @@ option(WITH_INFERENCE_API_TEST "Test fluid inference high-level api interface" option(WITH_SYSTEM_BLAS "Use system blas library" OFF) option(PY_VERSION "Compile PaddlePaddle with python3 support" ${PY_VERSION}) option(WITH_FAST_MATH "Make use of fast math library, might affect the precision to some extent" ON) -option(WITH_WBAES "Compile PaddlePaddle with WBAES support" ON) # PY_VERSION if(NOT PY_VERSION) @@ -149,7 +148,6 @@ include(external/dlpack) include(external/snappy) # download snappy include(external/snappystream) # download snappystream include(external/warpctc) # download, build, install warpctc -include(external/wbaes) # download wbaes if (NOT WIN32) # there is no official support of nccl, cupti in windows diff --git a/cmake/configure.cmake b/cmake/configure.cmake index 283845541b8e303babeed7ed9f9ece2d51a6a2fc..93d74bb0a8f726ad31685cbfc7831b5441cd5108 100644 --- a/cmake/configure.cmake +++ b/cmake/configure.cmake @@ -157,7 +157,3 @@ endif(WITH_BRPC_RDMA) if(ON_INFER) add_definitions(-DPADDLE_ON_INFERENCE) endif(ON_INFER) - -if(WITH_WBAES) - add_definitions(-DPADDLE_WITH_WBAES) -endif(WITH_WBAES) diff --git a/cmake/external/wbaes.cmake b/cmake/external/wbaes.cmake deleted file mode 100644 index feda5cb367aeb532702c9ab8560388d1207c201c..0000000000000000000000000000000000000000 --- a/cmake/external/wbaes.cmake +++ /dev/null @@ -1,71 +0,0 @@ -# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -IF(NOT ${WITH_WBAES}) - return() -ENDIF(NOT ${WITH_WBAES}) - -INCLUDE(ExternalProject) -SET(WBAES_DST_DIR "wbaes") -SET(WBAES_INSTALL_ROOT "${THIRD_PARTY_PATH}/install") -SET(WBAES_INSTALL_DIR ${WBAES_INSTALL_ROOT}/${WBAES_DST_DIR}) -SET(WBAES_ROOT ${WBAES_INSTALL_DIR}) -SET(WBAES_INC_DIR ${WBAES_ROOT}/include) -SET(WBAES_LIB_DIR ${WBAES_ROOT}/lib) - -SET(CMAKE_INSTALL_RPATH "${CMAKE_INSTALL_RPATH}" "${WBAES_ROOT}/lib") -SET(CMAKE_INSTALL_RPATH_USE_LINK_PATH TRUE) - -IF(APPLE) - SET(WBAES_TAG "v1.0.0" CACHE STRING "" FORCE) - SET(WBAES_URL "http://paddlepaddledeps.bj.bcebos.com/wbaes-sdk.mac.${WBAES_TAG}.tgz" CACHE STRING "" FORCE) - SET(WBAES_LIB ${WBAES_LIB_DIR}/libwbaes.dylib) - SET(WBAES_SHARED_LIB ${WBAES_LIB_DIR}/libwbaes.dylib) -ELSEIF(WIN32) - SET(WBAES_TAG "v1.0.0" CACHE STRING "" FORCE) - SET(WBAES_URL "http://paddlepaddledeps.bj.bcebos.com/wbaes-sdk.windows-x64.${WBAES_TAG}.tgz" CACHE STRING "" FORCE) - SET(WBAES_LIB ${WBAES_LIB_DIR}/libwbaes.lib) - SET(WBAES_SHARED_LIB ${WBAES_LIB_DIR}/libwbaes.dll) -ELSE() - SET(WBAES_TAG "v1.0.2" CACHE STRING "" FORCE) - SET(WBAES_URL "http://paddlepaddledeps.bj.bcebos.com/wbaes-sdk.linux-x86_64.${WBAES_TAG}.tgz" CACHE STRING "" FORCE) - SET(WBAES_LIB ${WBAES_LIB_DIR}/libwbaes.so) - SET(WBAES_SHARED_LIB ${WBAES_LIB_DIR}/libwbaes.so) -ENDIF() - -SET(WBAES_PROJECT "extern_wbaes") -MESSAGE(STATUS "WBAES_URL: ${WBAES_URL}, WBAES_LIB: ${WBAES_LIB}") -SET(WBAES_SOURCE_DIR "${THIRD_PARTY_PATH}/wbaes") -SET(WBAES_DOWNLOAD_DIR "${WBAES_SOURCE_DIR}/src/${WBAES_PROJECT}") - -ExternalProject_Add( - ${WBAES_PROJECT} - ${EXTERNAL_PROJECT_LOG_ARGS} - PREFIX ${WBAES_SOURCE_DIR} - URL ${WBAES_URL} - DOWNLOAD_DIR ${WBAES_DOWNLOAD_DIR} - DOWNLOAD_NO_PROGRESS 1 - CONFIGURE_COMMAND "" - BUILD_COMMAND "" - INSTALL_COMMAND "" - ${CMAKE_COMMAND} -E copy_directory ${WBAES_DOWNLOAD_DIR}/include ${WBAES_INC_DIR} && - ${CMAKE_COMMAND} -E copy_directory ${WBAES_DOWNLOAD_DIR}/lib ${WBAES_LIB_DIR} -) - -INCLUDE_DIRECTORIES(${WBAES_INC_DIR}) - -ADD_LIBRARY(wbaes SHARED IMPORTED GLOBAL) -SET_PROPERTY(TARGET wbaes PROPERTY IMPORTED_LOCATION ${WBAES_LIB}) -SET_PROPERTY(TARGET wbaes PROPERTY IMPORTED_NO_SONAME 1) -ADD_DEPENDENCIES(wbaes ${WBAES_PROJECT}) diff --git a/cmake/generic.cmake b/cmake/generic.cmake index 19110812c240db4cbe3ba73a3a42ab0f1511a115..6679a09dfc9dd00cfe3b5c5da3e12bd1c1389432 100644 --- a/cmake/generic.cmake +++ b/cmake/generic.cmake @@ -264,14 +264,6 @@ function(cc_library TARGET_NAME) list(REMOVE_ITEM cc_library_DEPS warpctc) add_dependencies(${TARGET_NAME} warpctc) endif() - # Only deps libwbaes.so, not link - if("${cc_library_DEPS};" MATCHES "wbaes;") - list(REMOVE_ITEM cc_library_DEPS wbaes) - if(NOT "${TARGET_NAME}" MATCHES "dynload_wbaes") - list(APPEND cc_library_DEPS dynload_wbaes) - endif() - add_dependencies(${TARGET_NAME} wbaes) - endif() # Only deps libmklml.so, not link if("${cc_library_DEPS};" MATCHES "mklml;") list(REMOVE_ITEM cc_library_DEPS mklml) diff --git a/cmake/inference_lib.cmake b/cmake/inference_lib.cmake index 2f558bffbd11a59699e050e6c8a53bca4cbb0884..b7c32f80db0dcb826f3f67ffb55da1c715785add 100644 --- a/cmake/inference_lib.cmake +++ b/cmake/inference_lib.cmake @@ -170,14 +170,6 @@ copy(snappystream_lib DSTS ${dst_dir} ${dst_dir}/lib DEPS snappystream) -if (WITH_WBAES) - set(dst_dir "${FLUID_INSTALL_DIR}/third_party/install/wbaes") - copy(wbaes_lib - SRCS ${WBAES_INC_DIR} ${WBAES_LIB} - DSTS ${dst_dir} ${dst_dir}/lib - DEPS wbaes) -endif () - set(dst_dir "${FLUID_INSTALL_DIR}/third_party/install/zlib") copy(zlib_lib SRCS ${ZLIB_INCLUDE_DIR} ${ZLIB_LIBRARIES} diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec index b19d50a6ad6afa312f5e695583174e56bf490755..d71d792b4e0147afc0ec09808d2004cb4b48ba46 100644 --- a/paddle/fluid/API.spec +++ b/paddle/fluid/API.spec @@ -235,6 +235,7 @@ paddle.fluid.layers.huber_loss (ArgSpec(args=['input', 'label', 'delta'], vararg paddle.fluid.layers.kldiv_loss (ArgSpec(args=['x', 'target', 'reduction', 'name'], varargs=None, keywords=None, defaults=('mean', None)), ('document', '776d536cac47c89073abc7ee524d5aec')) paddle.fluid.layers.tree_conv (ArgSpec(args=['nodes_vector', 'edge_set', 'output_size', 'num_filters', 'max_depth', 'act', 'param_attr', 'bias_attr', 'name'], varargs=None, keywords=None, defaults=(1, 2, 'tanh', None, None, None)), ('document', '34ea12ac9f10a65dccbc50100d12e607')) paddle.fluid.layers.npair_loss (ArgSpec(args=['anchor', 'positive', 'labels', 'l2_reg'], varargs=None, keywords=None, defaults=(0.002,)), ('document', '46994d10276dd4cb803b4062b5d14329')) +paddle.fluid.layers.pixel_shuffle (ArgSpec(args=['x', 'upscale_factor'], varargs=None, keywords=None, defaults=None), ('document', 'ad669cdf83e72a69ebc5ed79e36486de')) paddle.fluid.layers.fsp_matrix (ArgSpec(args=['x', 'y'], varargs=None, keywords=None, defaults=None), ('document', 'b76ccca3735bea4a58a0dbf0d77c5393')) paddle.fluid.layers.data (ArgSpec(args=['name', 'shape', 'append_batch_size', 'dtype', 'lod_level', 'type', 'stop_gradient'], varargs=None, keywords=None, defaults=(True, 'float32', 0, VarType.LOD_TENSOR, True)), ('document', '33bbd42027d872b3818b3d64ec52e139')) paddle.fluid.layers.open_files (ArgSpec(args=['filenames', 'shapes', 'lod_levels', 'dtypes', 'thread_num', 'buffer_size', 'pass_num', 'is_test'], varargs=None, keywords=None, defaults=(None, None, 1, None)), ('document', 'b1ae2e1cc0750e58726374061ea90ecc')) diff --git a/paddle/fluid/imperative/CMakeLists.txt b/paddle/fluid/imperative/CMakeLists.txt index 0d116a6495477ca69c10c130e63247a4f6c03b23..e52a0283f726640eb56b24a2978af6ee44e658ff 100644 --- a/paddle/fluid/imperative/CMakeLists.txt +++ b/paddle/fluid/imperative/CMakeLists.txt @@ -3,4 +3,7 @@ cc_library(layer SRCS layer.cc DEPS proto_desc operator device_context blas pybi cc_library(tracer SRCS tracer.cc DEPS proto_desc device_context pybind) cc_library(engine SRCS engine.cc) cc_library(imperative_profiler SRCS profiler.cc) +cc_library(nccl_context SRCS nccl_context.cc DEPS device_context) + +cc_test(nccl_context_test SRCS nccl_context_test.cc DEPS nccl_context) endif() diff --git a/paddle/fluid/imperative/nccl_context.cc b/paddle/fluid/imperative/nccl_context.cc new file mode 100644 index 0000000000000000000000000000000000000000..04364e68342810f6babee1df0c5eb3b476de022a --- /dev/null +++ b/paddle/fluid/imperative/nccl_context.cc @@ -0,0 +1,134 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/imperative/nccl_context.h" + +namespace paddle { +namespace imperative { +#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) +void NCCLParallelContext::RecvNCCLID(const std::string &ep, + ncclUniqueId *nccl_id) { + auto addr = paddle::string::Split(ep, ':'); + PADDLE_ENFORCE_EQ(addr.size(), 2UL, + "The endpoint should contain host and port: %s", ep); + std::string host = addr[0]; + int port = std::stoi(addr[1]); + + int server_fd, new_socket; + struct sockaddr_in address; + int addrlen = sizeof(address); + char buffer[1024] = {0}; + int opt = 0; + // creating socket fd + if ((server_fd = socket(AF_INET, SOCK_STREAM, 0)) == 0) + PADDLE_THROW("create server fd failed"); + if (setsockopt(server_fd, SOL_SOCKET, SO_REUSEADDR | SO_REUSEPORT, &opt, + sizeof(opt))) + PADDLE_THROW("set socket opt failed"); + + address.sin_family = AF_INET; + address.sin_addr.s_addr = INADDR_ANY; + address.sin_port = htons(port); + + if (bind(server_fd, (struct sockaddr *)&address, sizeof(address)) < 0) + PADDLE_THROW("binding failed on ep: %s", ep); + VLOG(3) << "listening on: " << ep; + if (listen(server_fd, 3) < 0) PADDLE_THROW("listen on server fd failed"); + + if ((new_socket = + accept(server_fd, reinterpret_cast(&address), + reinterpret_cast(&addrlen))) < 0) + PADDLE_THROW("accept the new socket fd failed"); + + if (read(new_socket, buffer, 1024) < 0) + PADDLE_THROW("reading the ncclUniqueId from socket failed"); + VLOG(3) << "recevived the ncclUniqueId"; + memcpy(nccl_id, buffer, NCCL_UNIQUE_ID_BYTES); + + VLOG(3) << "closing the socket server: " << ep; + close(server_fd); +} + +void NCCLParallelContext::SendNCCLID(const std::string &ep, + ncclUniqueId *nccl_id) { + auto addr = paddle::string::Split(ep, ':'); + PADDLE_ENFORCE_EQ(addr.size(), 2UL, + "The endpoint should contain host and port: %s", ep); + std::string host = addr[0]; + int port = std::stoi(addr[1]); + // struct sockaddr_in address; + int sock = 0; + struct sockaddr_in serv_addr; + char buffer[1024] = {0}; + + memcpy(buffer, nccl_id, NCCL_UNIQUE_ID_BYTES); + if ((sock = socket(AF_INET, SOCK_STREAM, 0)) < 0) + PADDLE_THROW("create socket failed"); + + memset(&serv_addr, '0', sizeof(serv_addr)); + serv_addr.sin_family = AF_INET; + serv_addr.sin_port = htons(port); + + if (inet_pton(AF_INET, host.c_str(), &serv_addr.sin_addr) <= 0) + PADDLE_THROW("invalied address: %s", ep); + + while (true) { + if (connect(sock, (struct sockaddr *)&serv_addr, sizeof(serv_addr)) < 0) { + VLOG(0) << "worker: " << ep + << " is not ready, will retry after 3 seconds..."; + std::this_thread::sleep_for(std::chrono::seconds(3)); + continue; + } + VLOG(3) << "sending the ncclUniqueId to " << ep; + send(sock, buffer, NCCL_UNIQUE_ID_BYTES, 0); + break; + } +} + +void NCCLParallelContext::BcastNCCLId(ncclUniqueId *nccl_id, int root) { + if (strategy_.local_rank_ == root) { + for (auto ep : strategy_.trainer_endpoints_) { + if (ep != strategy_.current_endpoint_) SendNCCLID(ep, nccl_id); + } + } else { + RecvNCCLID(strategy_.current_endpoint_, nccl_id); + } +} + +void NCCLParallelContext::Init() { + ncclUniqueId nccl_id; + ncclComm_t comm; + if (strategy_.local_rank_ == 0) { + // generate the unique ncclid on the root worker + platform::dynload::ncclGetUniqueId(&nccl_id); + BcastNCCLId(&nccl_id, 0); + } else { + BcastNCCLId(&nccl_id, 0); + } + int gpu_id = boost::get(place_).device; + VLOG(0) << "init nccl context nranks: " << strategy_.nranks_ + << " local rank: " << strategy_.local_rank_ << " gpu id: " << gpu_id; + + PADDLE_ENFORCE(cudaSetDevice(gpu_id)); + PADDLE_ENFORCE(platform::dynload::ncclCommInitRank( + &comm, strategy_.nranks_, nccl_id, strategy_.local_rank_)); + + platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); + auto *dev_ctx = static_cast(pool.Get(place_)); + dev_ctx->set_nccl_comm(comm); +} +#endif + +} // namespace imperative +} // namespace paddle diff --git a/paddle/fluid/imperative/nccl_context.h b/paddle/fluid/imperative/nccl_context.h new file mode 100644 index 0000000000000000000000000000000000000000..51c86190439d1739a99bc91a712f663383bb9371 --- /dev/null +++ b/paddle/fluid/imperative/nccl_context.h @@ -0,0 +1,81 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once + +// network header files +#ifndef _WIN32 +#include +#include +#include +#include +#endif + +#include +#include + +#include "paddle/fluid/framework/variable.h" +#include "paddle/fluid/platform/device_context.h" +#ifdef PADDLE_WITH_CUDA +#include "paddle/fluid/platform/dynload/nccl.h" +#endif +#include "paddle/fluid/platform/place.h" +#include "paddle/fluid/string/split.h" + +namespace paddle { +namespace imperative { + +struct ParallelStrategy { + int nranks_{1}; + int local_rank_{0}; + std::vector trainer_endpoints_{}; + std::string current_endpoint_{""}; +}; + +class ParallelContext { + public: + explicit ParallelContext(const ParallelStrategy& strategy, + const platform::Place& place) + : strategy_(strategy), place_(place) {} + + virtual ~ParallelContext() {} + + virtual void Init() = 0; + + protected: + ParallelStrategy strategy_; + platform::Place place_; +}; + +#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) +class NCCLParallelContext : ParallelContext { + public: + explicit NCCLParallelContext(const ParallelStrategy& strategy, + const platform::Place& place) + : ParallelContext(strategy, place) {} + + ~NCCLParallelContext() {} + + void BcastNCCLId(ncclUniqueId* nccl_id, int root); + + void Init() override; + + protected: + void RecvNCCLID(const std::string& endpoint, ncclUniqueId* nccl_id); + + void SendNCCLID(const std::string& endpoint, ncclUniqueId* nccl_id); +}; +#endif + +} // namespace imperative +} // namespace paddle diff --git a/paddle/fluid/imperative/nccl_context_test.cc b/paddle/fluid/imperative/nccl_context_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..74a74ebe921378e2994a6a4cb2087d0acde950b1 --- /dev/null +++ b/paddle/fluid/imperative/nccl_context_test.cc @@ -0,0 +1,52 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/imperative/nccl_context.h" +#include "gtest/gtest.h" +#include "paddle/fluid/platform/device_context.h" + +namespace imperative = paddle::imperative; +namespace platform = paddle::platform; + +imperative::ParallelStrategy GetStrategy(int local_rank) { + std::vector eps = {"127.0.0.1:9866", "127.0.0.1:9867"}; + imperative::ParallelStrategy strategy; + strategy.trainer_endpoints_ = eps; + strategy.current_endpoint_ = eps[local_rank]; + strategy.nranks_ = 2; + strategy.local_rank_ = local_rank; + return strategy; +} + +#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) +void BcastNCCLId(int local_rank, ncclUniqueId *nccl_id) { + auto strategy = GetStrategy(local_rank); + platform::CUDAPlace gpu(local_rank); + imperative::NCCLParallelContext ctx(strategy, gpu); + ctx.BcastNCCLId(nccl_id, 0); +} + +TEST(BcastNCCLId, Run) { + ncclUniqueId nccl_id; + platform::dynload::ncclGetUniqueId(&nccl_id); + std::thread t(BcastNCCLId, 0, &nccl_id); + + ncclUniqueId recv_nccl_id; + BcastNCCLId(1, &recv_nccl_id); + + t.join(); + EXPECT_EQ(0, std::memcmp(nccl_id.internal, recv_nccl_id.internal, + NCCL_UNIQUE_ID_BYTES)); +} +#endif diff --git a/paddle/fluid/imperative/tracer.cc b/paddle/fluid/imperative/tracer.cc index 7c9d0af3ecd647604ab46ee6239fc352e5fd8d85..7c495ddd68221acfed8537fd72e9a582e891f8db 100644 --- a/paddle/fluid/imperative/tracer.cc +++ b/paddle/fluid/imperative/tracer.cc @@ -177,7 +177,7 @@ std::set Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs, current_vars_map[out->Name()] = out; } - VLOG(3) << "input var name: " << out->Name() + VLOG(3) << "output var name: " << out->Name() << " inited: " << out->var_->IsInitialized() << " stop_grad: " << out->IsStopGradient(); } @@ -215,6 +215,7 @@ std::set Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs, framework::Scope scope; op->place_ = GetExpectedPlace(expected_place, inputs); + PreparedOp prepared_op = PreparedOp::Prepare(ctx, *op_kernel, op->place_); prepared_op.op.RuntimeInferShape(scope, op->place_, ctx); prepared_op.func( diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc index e765c078aa838de6513e6f4d6729e3b1fb2958db..1a87ac378a232f153a994c6b11ff37f8f5419a36 100644 --- a/paddle/fluid/inference/api/paddle_pass_builder.cc +++ b/paddle/fluid/inference/api/paddle_pass_builder.cc @@ -86,7 +86,8 @@ const std::vector kAnakinSubgraphPasses({ GpuPassStrategy::GpuPassStrategy() : PassStrategy({}) { passes_.assign({ - "infer_clean_graph_pass", // + "infer_clean_graph_pass", // + "runtime_context_cache_pass", // // "identity_scale_op_clean_pass", // "conv_affine_channel_fuse_pass", // "conv_eltwiseadd_affine_channel_fuse_pass", // @@ -96,7 +97,6 @@ GpuPassStrategy::GpuPassStrategy() : PassStrategy({}) { "conv_elementwise_add_act_fuse_pass", // "conv_elementwise_add2_act_fuse_pass", // "conv_elementwise_add_fuse_pass", // - "runtime_context_cache_pass", // #endif // "transpose_flatten_concat_fuse_pass", }); @@ -116,7 +116,11 @@ CpuPassStrategy::CpuPassStrategy() : PassStrategy({}) { // NOTE the large fusions should be located in the front, so that they will // not be damaged by smaller ones. passes_.assign({ - "infer_clean_graph_pass", // + "infer_clean_graph_pass", // + // TODO(luotao): runtime_context_cache_pass should be located in the + // front, see https://github.com/PaddlePaddle/Paddle/issues/16609, + // will enhance this pass later. + "runtime_context_cache_pass", // "attention_lstm_fuse_pass", // "seqpool_concat_fuse_pass", // "seqconv_eltadd_relu_fuse_pass", // @@ -132,7 +136,6 @@ CpuPassStrategy::CpuPassStrategy() : PassStrategy({}) { "conv_bn_fuse_pass", // "conv_eltwiseadd_bn_fuse_pass", // "is_test_pass", // - "runtime_context_cache_pass", // }); use_gpu_ = false; diff --git a/paddle/fluid/inference/tests/api/analyzer_int8_image_classification_tester.cc b/paddle/fluid/inference/tests/api/analyzer_int8_image_classification_tester.cc index ece094717b8076321c68d7fdd29f07c4da6b0ed4..fbf67d933786e3ee2baab7a20911da2837cdce4d 100644 --- a/paddle/fluid/inference/tests/api/analyzer_int8_image_classification_tester.cc +++ b/paddle/fluid/inference/tests/api/analyzer_int8_image_classification_tester.cc @@ -23,18 +23,11 @@ namespace analysis { void SetConfig(AnalysisConfig *cfg) { cfg->SetModel(FLAGS_infer_model); - cfg->SetProgFile("__model__"); cfg->DisableGpu(); cfg->SwitchIrOptim(); - cfg->SwitchSpecifyInputNames(false); + cfg->SwitchSpecifyInputNames(); cfg->SetCpuMathLibraryNumThreads(FLAGS_paddle_num_threads); cfg->EnableMKLDNN(); - cfg->pass_builder()->SetPasses( - {"infer_clean_graph_pass", "mkldnn_placement_pass", - "depthwise_conv_mkldnn_pass", "conv_bn_fuse_pass", - "conv_eltwiseadd_bn_fuse_pass", "conv_bias_mkldnn_fuse_pass", - "conv_elementwise_add_mkldnn_fuse_pass", "conv_relu_mkldnn_fuse_pass", - "fc_fuse_pass", "is_test_pass"}); } template @@ -84,13 +77,13 @@ std::shared_ptr> GetWarmupData( std::to_string(num_images) + " is bigger than all test data size."); PaddleTensor images; - images.name = "input"; + images.name = "image"; images.shape = {num_images, 3, 224, 224}; images.dtype = PaddleDType::FLOAT32; images.data.Resize(sizeof(float) * num_images * 3 * 224 * 224); PaddleTensor labels; - labels.name = "labels"; + labels.name = "label"; labels.shape = {num_images, 1}; labels.dtype = PaddleDType::INT64; labels.data.Resize(sizeof(int64_t) * num_images); @@ -132,7 +125,7 @@ void SetInput(std::vector> *inputs, images_offset_in_file + sizeof(float) * total_images * 3 * 224 * 224; TensorReader image_reader(file, images_offset_in_file, - image_batch_shape, "input"); + image_batch_shape, "image"); TensorReader label_reader(file, labels_offset_in_file, label_batch_shape, "label"); diff --git a/paddle/fluid/inference/tests/api/int8_mkldnn_quantization.md b/paddle/fluid/inference/tests/api/int8_mkldnn_quantization.md new file mode 100644 index 0000000000000000000000000000000000000000..cbeef5fb9da42388eade6fa90344abf77cb59bd6 --- /dev/null +++ b/paddle/fluid/inference/tests/api/int8_mkldnn_quantization.md @@ -0,0 +1,70 @@ +# INT8 MKL-DNN quantization + +This document describes how to use Paddle inference Engine to convert the FP32 model to INT8 model on ResNet-50 and MobileNet-V1. We provide the instructions on enabling INT8 MKL-DNN quantization in Paddle inference and show the ResNet-50 and MobileNet-V1 results in accuracy and performance. + +## 0. Install PaddlePaddle +Follow PaddlePaddle [installation instruction](https://github.com/PaddlePaddle/models/tree/develop/fluid/PaddleCV/image_classification#installation) to install PaddlePaddle. If you build PaddlePaddle yourself, please use the following cmake arguments. +``` +cmake .. -DWITH_TESTING=ON -WITH_FLUID_ONLY=ON -DWITH_GPU=OFF -DWITH_MKL=ON -WITH_SWIG_PY=OFF -DWITH_INFERENCE_API_TEST=ON -DON_INFER=ON + +``` +Note: MKL-DNN and MKL are required. + +## 1. Enable INT8 MKL-DNN quantization +For reference, please examine the code of unit test enclosed in [analyzer_int8_image_classification_tester.cc](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/fluid/inference/tests/api/analyzer_int8_image_classification_tester.cc). + +* ### Create Analysis config +INT8 quantization is one of the optimizations in analysis config. More information about analysis config can be found [here](https://github.com/PaddlePaddle/FluidDoc/blob/develop/doc/fluid/advanced_usage/deploy/inference/native_infer_en.md#upgrade-performance-based-on-contribanalysisconfig-prerelease) + +* ### Create quantize config by analysis config +We enable the MKL-DNN quantization procedure by calling an appropriate method from analysis config. Afterwards, all the required quantization parameters (quantization op names, quantization strategies etc.) can be set through quantizer config which is present in the analysis config. It is also necessary to specify a pre-processed warmup dataset and desired batch size. + +```cpp +//Enable MKL-DNN quantization +cfg.EnableMkldnnQuantizer(); + +//use analysis config to call the MKL-DNN quantization config +cfg.mkldnn_quantizer_config()->SetWarmupData(warmup_data); +cfg.mkldnn_quantizer_config()->SetWarmupBatchSize(100); +``` + +## 2. Accuracy and Performance benchmark + +We provide the results of accuracy and performance measured on Intel(R) Xeon(R) Gold 6271 on single core. + + >**I. Top-1 Accuracy on Intel(R) Xeon(R) Gold 6271** + +| Model | Dataset | FP32 Accuracy | INT8 Accuracy | Accuracy Diff | +| :------------: | :------------: | :------------: | :------------: | :------------: | +| ResNet-50 | Full ImageNet Val | 76.63% | 76.48% | 0.15% | +| MobileNet-V1 | Full ImageNet Val | 70.78% | 70.36% | 0.42% | + + >**II. Throughput on Intel(R) Xeon(R) Gold 6271 (batch size 1 on single core)** + +| Model | Dataset | FP32 Throughput | INT8 Throughput | Ratio(INT8/FP32) | +| :------------: | :------------: | :------------: | :------------: | :------------: | +| ResNet-50 | Full ImageNet Val | 13.17 images/s | 49.84 images/s | 3.78 | +| MobileNet-V1 | Full ImageNet Val | 75.49 images/s | 232.38 images/s | 3.07 | + +Notes: +* Measurement of accuracy requires a model which accepts two inputs: data and labels. +* Different sampling batch size data may cause slight difference on INT8 top accuracy. +* CAPI performance data is better than python API performance data because of the python overhead. Especially for the small computational model, python overhead will be more obvious. + + +## 3. Commands to reproduce the above accuracy and performance benchmark +* #### Full dataset (Single core) + * ##### Download full ImageNet Validation Dataset +```bash +cd /PATH/TO/PADDLE/build +python ../paddle/fluid/inference/tests/api/full_ILSVRC2012_val_preprocess.py +``` +The converted data binary file is saved by default in ~/.cache/paddle/dataset/int8/download/int8_full_val.bin + * ##### ResNet50 Full dataset benchmark +```bash +./paddle/fluid/inference/tests/api/test_analyzer_int8_resnet50 --infer_model=third_party/inference_demo/int8v2/resnet50/model --infer_data=/path/to/converted/int8_full_val.bin --batch_size=1 --paddle_num_threads=1 +``` + * ##### Mobilenet-v1 Full dataset benchmark +```bash +./paddle/fluid/inference/tests/api/test_analyzer_int8_mobilenet --infer_model=third_party/inference_demo/int8v2/mobilenet/model --infer_data=/path/to/converted/int8_full_val.bin --batch_size=1 --paddle_num_threads=1 +``` diff --git a/paddle/fluid/inference/tests/api/tester_helper.h b/paddle/fluid/inference/tests/api/tester_helper.h index 9a0dcc722cf00984b8c0e3ac20f13849e2904102..5cc54ed299c50b48c83de2742b715b16cf1f8cd0 100644 --- a/paddle/fluid/inference/tests/api/tester_helper.h +++ b/paddle/fluid/inference/tests/api/tester_helper.h @@ -316,7 +316,8 @@ void PredictionRun(PaddlePredictor *predictor, int num_threads, int tid) { int num_times = FLAGS_repeat; int iterations = inputs.size(); // process the whole dataset ... - if (FLAGS_iterations > 0 && FLAGS_iterations < inputs.size()) + if (FLAGS_iterations > 0 && + FLAGS_iterations < static_cast(inputs.size())) iterations = FLAGS_iterations; // ... unless the number of iterations is set outputs->resize(iterations); @@ -329,14 +330,14 @@ void PredictionRun(PaddlePredictor *predictor, #endif if (!FLAGS_zero_copy) { run_timer.tic(); - for (size_t i = 0; i < iterations; i++) { + for (int i = 0; i < iterations; i++) { for (int j = 0; j < num_times; j++) { predictor->Run(inputs[i], &(*outputs)[i], FLAGS_batch_size); } } elapsed_time = run_timer.toc(); } else { - for (size_t i = 0; i < iterations; i++) { + for (int i = 0; i < iterations; i++) { ConvertPaddleTensorToZeroCopyTensor(predictor, inputs[i]); run_timer.tic(); for (int j = 0; j < num_times; j++) { @@ -366,9 +367,8 @@ void TestOneThreadPrediction( const std::vector> &inputs, std::vector> *outputs, bool use_analysis = true) { auto predictor = CreateTestPredictor(config, use_analysis); - PredictionWarmUp(predictor.get(), inputs, outputs, FLAGS_paddle_num_threads, - 0); - PredictionRun(predictor.get(), inputs, outputs, FLAGS_paddle_num_threads, 0); + PredictionWarmUp(predictor.get(), inputs, outputs, 1, 0); + PredictionRun(predictor.get(), inputs, outputs, 1, 0); } void TestMultiThreadPrediction( diff --git a/paddle/fluid/operators/controlflow/conditional_block_op.cc b/paddle/fluid/operators/controlflow/conditional_block_op.cc index dd28f82b65403550c67418cae535bbfeeef4476e..f0dc718195506e89bf9fecc0eb5e0d5117275a33 100644 --- a/paddle/fluid/operators/controlflow/conditional_block_op.cc +++ b/paddle/fluid/operators/controlflow/conditional_block_op.cc @@ -12,6 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include +#include +#include +#include #include "paddle/fluid/framework/executor.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/var_type.h" @@ -174,24 +177,41 @@ class ConditionalBlockGradOp : public ConditionalOp { framework::Executor exec(dev_place); auto *block = Attr("sub_block"); - exec.Run(*block->Program(), &cur_scope, block->ID(), false); - AssignLocalGradientToGlobal(dev_place, cur_scope, Inputs("Input"), - Outputs(framework::GradVarName("Input"))); + const auto &ins = Inputs("Input"); + const auto &d_ins = Outputs(framework::GradVarName("Input")); + const auto &conds = Inputs("Cond"); + const auto &d_conds = Outputs(framework::GradVarName("Cond")); + + std::vector ins_conds_grads; + ins_conds_grads.reserve(ins.size() + conds.size()); + for (auto &in : ins) { + ins_conds_grads.emplace_back(framework::GradVarName(in)); + } + for (auto &cond : conds) { + ins_conds_grads.emplace_back(framework::GradVarName(cond)); + } + + exec.Run(*block->Program(), &cur_scope, block->ID(), false, true, + ins_conds_grads); + + AssignLocalGradientToGlobal(dev_place, cur_scope, ins_conds_grads.data(), + ins.size(), d_ins); - AssignLocalGradientToGlobal(dev_place, cur_scope, Inputs("Cond"), - Outputs(framework::GradVarName("Cond"))); + AssignLocalGradientToGlobal(dev_place, cur_scope, + ins_conds_grads.data() + ins.size(), + conds.size(), d_conds); } } private: void AssignLocalGradientToGlobal( const platform::Place &place, const framework::Scope &cur_scope, - const std::vector &p_names, + const std::string *p_grad_names, size_t p_grad_names_num, const std::vector &pg_names) const { - for (size_t i = 0; i < p_names.size(); ++i) { + for (size_t i = 0; i < p_grad_names_num; ++i) { auto out_grad_name = pg_names[i]; - auto in_grad_name = framework::GradVarName(p_names[i]); + const auto &in_grad_name = p_grad_names[i]; auto *in_var = cur_scope.FindVar(in_grad_name); if (in_var == nullptr) { continue; diff --git a/paddle/fluid/operators/jit/test.cc b/paddle/fluid/operators/jit/test.cc index d30fa014ed5fbac9ed71f3185ce0443d33f4a281..875d4f864353c131ca4d72b5176adcae8aff724a 100644 --- a/paddle/fluid/operators/jit/test.cc +++ b/paddle/fluid/operators/jit/test.cc @@ -991,15 +991,17 @@ TEST(JITKernel_pool, jitpool) { TEST(JITKernel_pool, more) { const auto& kers = jit::KernelPool::Instance().AllKernels(); -#if defined(__APPLE__) || defined(__OSX__) - EXPECT_EQ(kers.size(), 10UL); -#else -#ifdef PADDLE_WITH_MKLML - EXPECT_EQ(kers.size(), 22UL); -#else - EXPECT_EQ(kers.size(), 8UL); + size_t target_num = 8; + +#ifdef __AVX__ + target_num += 2; #endif + +#ifdef PADDLE_WITH_MKLML + target_num += 12; #endif + + EXPECT_EQ(kers.size(), target_num); } TEST(JITKernel_pool, refer) { diff --git a/paddle/fluid/operators/load_op.cc b/paddle/fluid/operators/load_op.cc index 656728c609eb19f90390d9dec72d9e30fd3040fd..435c755df3642ae0ba5144a89ed30ed6e0b63258 100644 --- a/paddle/fluid/operators/load_op.cc +++ b/paddle/fluid/operators/load_op.cc @@ -29,7 +29,7 @@ class LoadOp : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { framework::OpKernelType kt = framework::OpKernelType( - framework::proto::VarType::FP32, platform::CPUPlace()); + framework::proto::VarType::FP32, ctx.GetPlace()); return kt; } }; diff --git a/paddle/fluid/operators/pixel_shuffle_op.cc b/paddle/fluid/operators/pixel_shuffle_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..59ba660af79bff02cd350afb3eb7675bfe8ac498 --- /dev/null +++ b/paddle/fluid/operators/pixel_shuffle_op.cc @@ -0,0 +1,135 @@ +/*Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/pixel_shuffle_op.h" +#include + +namespace paddle { +namespace operators { + +class PixelShuffleOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), + "Input(X) of PixelShuffleOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Out"), + "Output(Out) of PixelShuffleOp should not be null."); + + auto input_dims = ctx->GetInputDim("X"); + PADDLE_ENFORCE(input_dims.size() == 4, "The layout of input is NCHW."); + auto upscale_factor = ctx->Attrs().Get("upscale_factor"); + + PADDLE_ENFORCE(input_dims[1] % (upscale_factor * upscale_factor) == 0, + "Upscale_factor should devide the number of channel"); + + auto output_dims = input_dims; + output_dims[0] = input_dims[0]; + output_dims[1] = input_dims[1] / (upscale_factor * upscale_factor); + output_dims[2] = input_dims[2] * upscale_factor; + output_dims[3] = input_dims[3] * upscale_factor; + ctx->SetOutputDim("Out", output_dims); + } +}; + +class PixelShuffleOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput( + "X", + "(Tensor, default Tensor), " + "the input feature data of PixelShuffleOp, the layout is [N C H W]."); + AddOutput( + "Out", + "(Tensor, default Tensor), the output of " + "PixelShuffleOp. The layout is [N,C/factor^2,H*factor,W*factor]."); + AddAttr("upscale_factor", + "the factor to increase spatial resolution by.") + .SetDefault(1) + .AddCustomChecker([](const int& upscale_factor) { + PADDLE_ENFORCE_GE(upscale_factor, 1, + "upscale_factor should be larger than 0."); + }); + + AddComment(R"DOC( + Pixel Shuffle operator + This operator rearranges elements in a tensor of shape :math:`(*, C \times r^2, H, W)` + to a tensor of shape :math:`(C, H \times r, W \times r)`. + + This is useful for implementing efficient sub-pixel convolution + with a stride of :math:`1/r`. + + Please refer to the paper: + `Real-Time Single Image and Video Super-Resolution Using an Efficient + Sub-Pixel Convolutional Neural Network `_ + by Shi et. al (2016) for more details. + + )DOC"); + } +}; + +class PixelShuffleGradMaker : public framework::SingleGradOpDescMaker { + public: + using framework::SingleGradOpDescMaker::SingleGradOpDescMaker; + + std::unique_ptr Apply() const override { + auto* op = new framework::OpDesc(); + op->SetType("pixel_shuffle_grad"); + op->SetInput(framework::GradVarName("Out"), OutputGrad("Out")); + op->SetAttrMap(Attrs()); + op->SetOutput(framework::GradVarName("X"), InputGrad("X")); + return std::unique_ptr(op); + } +}; + +class PixelShuffleGradOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")), + "Input(Out@Grad) should not be null"); + PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")), + "Output(X@Grad) should not be null"); + + auto do_dims = ctx->GetInputDim(framework::GradVarName("Out")); + PADDLE_ENFORCE(do_dims.size() == 4, "The layout of input is NCHW."); + + auto upscale_factor = ctx->Attrs().Get("upscale_factor"); + + auto dx_dims = do_dims; + dx_dims[0] = do_dims[0]; + dx_dims[1] = do_dims[1] * (upscale_factor * upscale_factor); + dx_dims[2] = do_dims[2] / upscale_factor; + dx_dims[3] = do_dims[3] / upscale_factor; + ctx->SetOutputDim(framework::GradVarName("X"), dx_dims); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR(pixel_shuffle, ops::PixelShuffleOp, ops::PixelShuffleOpMaker, + ops::PixelShuffleGradMaker); + +REGISTER_OPERATOR(pixel_shuffle_grad, ops::PixelShuffleGradOp); + +REGISTER_OP_CPU_KERNEL( + pixel_shuffle, + ops::PixelShuffleOpKernel, + ops::PixelShuffleOpKernel); + +REGISTER_OP_CPU_KERNEL( + pixel_shuffle_grad, + ops::PixelShuffleGradOpKernel, + ops::PixelShuffleGradOpKernel); diff --git a/paddle/fluid/platform/dynload/wbaes.cc b/paddle/fluid/operators/pixel_shuffle_op.cu similarity index 55% rename from paddle/fluid/platform/dynload/wbaes.cc rename to paddle/fluid/operators/pixel_shuffle_op.cu index 37387b202aadddef859b0eecca55cb9c99d826ee..6faf91079e1dac00b3516ccde8dc82cec73a79e6 100644 --- a/paddle/fluid/platform/dynload/wbaes.cc +++ b/paddle/fluid/operators/pixel_shuffle_op.cu @@ -12,23 +12,15 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#ifdef PADDLE_WITH_WBAES - -#include "paddle/fluid/platform/dynload/wbaes.h" - -namespace paddle { -namespace platform { -namespace dynload { - -std::once_flag wbaes_dso_flag; -void *wbaes_dso_handle = nullptr; - -#define DEFINE_WRAP(__name) DynLoad__##__name __name - -WBAES_ROUTINE_EACH(DEFINE_WRAP); - -} // namespace dynload -} // namespace platform -} // namespace paddle - -#endif +#include "paddle/fluid/operators/pixel_shuffle_op.h" + +namespace ops = paddle::operators; +namespace plat = paddle::platform; + +REGISTER_OP_CUDA_KERNEL( + pixel_shuffle, ops::PixelShuffleOpKernel, + ops::PixelShuffleOpKernel); +REGISTER_OP_CUDA_KERNEL( + pixel_shuffle_grad, + ops::PixelShuffleGradOpKernel, + ops::PixelShuffleGradOpKernel); diff --git a/paddle/fluid/operators/pixel_shuffle_op.h b/paddle/fluid/operators/pixel_shuffle_op.h new file mode 100644 index 0000000000000000000000000000000000000000..1ae1c7e9d50cb9d701fd0e79337a1906f2f5d545 --- /dev/null +++ b/paddle/fluid/operators/pixel_shuffle_op.h @@ -0,0 +1,82 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include +#include +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/math/math_function.h" + +namespace paddle { +namespace operators { + +template +class PixelShuffleOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* in = ctx.Input("X"); + auto* out = ctx.Output("Out"); + out->mutable_data(ctx.GetPlace()); + + int factor = ctx.Attr("upscale_factor"); + + auto in_dims = in->dims(); + auto o_dims = out->dims(); + + framework::Tensor t; + t.ShareDataWith(*in); + t.Resize({in_dims[0], o_dims[1], factor, factor, in_dims[2], in_dims[3]}); + + std::vector axis = {0, 1, 4, 2, 5, 3}; + + framework::Tensor o; + o.ShareDataWith(*out); + o.Resize({in_dims[0], o_dims[1], in_dims[2], factor, in_dims[3], factor}); + + math::Transpose trans; + auto& dev_ctx = ctx.template device_context(); + trans(dev_ctx, t, &o, axis); + out->Resize(o_dims); + } +}; + +template +class PixelShuffleGradOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* dout = ctx.Input(framework::GradVarName("Out")); + auto* dx = ctx.Output(framework::GradVarName("X")); + dx->mutable_data(ctx.GetPlace()); + + int factor = ctx.Attr("upscale_factor"); + + auto do_dims = dout->dims(); + auto dx_dims = dx->dims(); + + framework::Tensor t; + t.ShareDataWith(*dout); + t.Resize({do_dims[0], do_dims[1], dx_dims[2], factor, dx_dims[3], factor}); + + std::vector axis = {0, 1, 3, 5, 2, 4}; + + framework::Tensor o; + o.ShareDataWith(*dx); + o.Resize({do_dims[0], do_dims[1], factor, factor, dx_dims[2], dx_dims[3]}); + + math::Transpose trans; + auto& dev_ctx = ctx.template device_context(); + trans(dev_ctx, t, &o, axis); + dx->Resize(dx_dims); + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/platform/dynload/CMakeLists.txt b/paddle/fluid/platform/dynload/CMakeLists.txt index 1697343790d13c37d63505acfe471b379bf897d9..07159d4a12ef4b628f7705ed206d3334be46dfc8 100644 --- a/paddle/fluid/platform/dynload/CMakeLists.txt +++ b/paddle/fluid/platform/dynload/CMakeLists.txt @@ -17,9 +17,6 @@ if (CUPTI_FOUND) endif(CUPTI_FOUND) nv_library(dynload_cuda SRCS ${CUDA_SRCS} DEPS dynamic_loader) cc_library(dynload_warpctc SRCS warpctc.cc DEPS dynamic_loader warpctc) -if (WITH_WBAES) - cc_library(dynload_wbaes SRCS wbaes.cc DEPS dynamic_loader wbaes) -endif() if (WITH_MKLML) cc_library(dynload_mklml SRCS mklml.cc DEPS dynamic_loader mklml) endif() diff --git a/paddle/fluid/platform/dynload/dynamic_loader.cc b/paddle/fluid/platform/dynload/dynamic_loader.cc index 8ac9393787324d3a8a17ac5a800bcf69638a4fed..15d516836652ea4ea4d1bcdf35022e6b79cc3b52 100644 --- a/paddle/fluid/platform/dynload/dynamic_loader.cc +++ b/paddle/fluid/platform/dynload/dynamic_loader.cc @@ -48,8 +48,6 @@ DEFINE_string( DEFINE_string(mklml_dir, "", "Specify path for loading libmklml_intel.so."); -DEFINE_string(wbaes_dir, "", "Specify path for loading libwbaes.so."); - namespace paddle { namespace platform { namespace dynload { @@ -248,16 +246,6 @@ void* GetMKLMLDsoHandle() { #endif } -void* GetWBAESDsoHandle() { -#if defined(__APPLE__) || defined(__OSX__) - return GetDsoHandleFromSearchPath(FLAGS_wbaes_dir, "libwbaes.dylib"); -#elif defined(_WIN32) - return GetDsoHandleFromSearchPath(FLAGS_wbaes_dir, "libwbaes.dll"); -#else - return GetDsoHandleFromSearchPath(FLAGS_wbaes_dir, "libwbaes.so"); -#endif -} - } // namespace dynload } // namespace platform } // namespace paddle diff --git a/paddle/fluid/platform/dynload/dynamic_loader.h b/paddle/fluid/platform/dynload/dynamic_loader.h index 5a642967c7666f5d5943214f557786c87491d740..edb4c649addfaf941a00588395d9191038217979 100644 --- a/paddle/fluid/platform/dynload/dynamic_loader.h +++ b/paddle/fluid/platform/dynload/dynamic_loader.h @@ -32,7 +32,6 @@ void* GetWarpCTCDsoHandle(); void* GetNCCLDsoHandle(); void* GetTensorRtDsoHandle(); void* GetMKLMLDsoHandle(); -void* GetWBAESDsoHandle(); } // namespace dynload } // namespace platform diff --git a/paddle/fluid/platform/dynload/wbaes.h b/paddle/fluid/platform/dynload/wbaes.h deleted file mode 100644 index 22400d44e4ca5568f1d74e4e194e45e81cbdfefe..0000000000000000000000000000000000000000 --- a/paddle/fluid/platform/dynload/wbaes.h +++ /dev/null @@ -1,63 +0,0 @@ -/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ -#pragma once - -#ifdef PADDLE_WITH_WBAES - -#include -#include // NOLINT - -#include "paddle/fluid/platform/dynload/dynamic_loader.h" -#include "paddle/fluid/platform/port.h" - -namespace paddle { -namespace platform { -namespace dynload { - -extern std::once_flag wbaes_dso_flag; -extern void *wbaes_dso_handle; - -/** - * The following macro definition can generate structs - * (for each function) to dynamic load wbaes routine - * via operator overloading. - */ - -#define DYNAMIC_LOAD_WBAES_WRAP(__name) \ - struct DynLoad__##__name { \ - template \ - auto operator()(Args... args) -> DECLARE_TYPE(__name, args...) { \ - using wbaesFunc = decltype(&::__name); \ - std::call_once(wbaes_dso_flag, []() { \ - wbaes_dso_handle = paddle::platform::dynload::GetWBAESDsoHandle(); \ - }); \ - static void *p_##__name = dlsym(wbaes_dso_handle, #__name); \ - return reinterpret_cast(p_##__name)(args...); \ - } \ - }; \ - extern DynLoad__##__name __name - -#define DECLARE_DYNAMIC_LOAD_WBAES_WRAP(__name) DYNAMIC_LOAD_WBAES_WRAP(__name) - -#define WBAES_ROUTINE_EACH(__macro) __macro(GSECF); - -WBAES_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_WBAES_WRAP); - -#undef DYNAMIC_LOAD_WBAES_WRAP - -} // namespace dynload -} // namespace platform -} // namespace paddle - -#endif diff --git a/paddle/fluid/pybind/CMakeLists.txt b/paddle/fluid/pybind/CMakeLists.txt index c8a0aa58859cca06375ce578e5a7097179e23107..16365c1fd0b0adb914cdfd08e3f6542fca952e06 100644 --- a/paddle/fluid/pybind/CMakeLists.txt +++ b/paddle/fluid/pybind/CMakeLists.txt @@ -1,6 +1,6 @@ set(PYBIND_DEPS pybind python proto_desc memory executor async_executor fleet_wrapper prune feed_fetch_method pass_builder parallel_executor profiler layer scope_pool - tracer analysis_predictor imperative_profiler) + tracer analysis_predictor imperative_profiler nccl_context) if(WITH_PYTHON) list(APPEND PYBIND_DEPS py_func_op) diff --git a/paddle/fluid/pybind/imperative.cc b/paddle/fluid/pybind/imperative.cc index e9ed4e16443eba481143bd2095f9970bcb167d71..265707f1bccdabd37b9a7248755d0b81339418c3 100644 --- a/paddle/fluid/pybind/imperative.cc +++ b/paddle/fluid/pybind/imperative.cc @@ -29,7 +29,7 @@ namespace paddle { namespace pybind { // Bind Methods -void BindTracer(pybind11::module* m) { +void BindImperative(pybind11::module* m) { pybind11::class_(*m, "Tracer", "") .def("__init__", [](imperative::Tracer& self, framework::BlockDesc* root_block) { @@ -59,6 +59,47 @@ void BindTracer(pybind11::module* m) { }) .def("py_trace", &imperative::Tracer::PyTrace, pybind11::return_value_policy::take_ownership); + + // define parallel context + pybind11::class_ parallel_strategy( + *m, "ParallelStrategy", ""); + parallel_strategy.def(pybind11::init()) + .def_property( + "nranks", + [](const imperative::ParallelStrategy& self) { return self.nranks_; }, + [](imperative::ParallelStrategy& self, int nranks) { + self.nranks_ = nranks; + }) + .def_property("local_rank", + [](const imperative::ParallelStrategy& self) { + return self.local_rank_; + }, + [](imperative::ParallelStrategy& self, int local_rank) { + self.local_rank_ = local_rank; + }) + .def_property( + "trainer_endpoints", + [](const imperative::ParallelStrategy& self) { + return self.trainer_endpoints_; + }, + [](imperative::ParallelStrategy& self, std::vector eps) { + self.trainer_endpoints_ = eps; + }) + .def_property("current_endpoint", + [](const imperative::ParallelStrategy& self) { + return self.current_endpoint_; + }, + [](imperative::ParallelStrategy& self, + const std::string& ep) { self.current_endpoint_ = ep; }); +#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) + pybind11::class_ nccl_ctx( + *m, "NCCLParallelContext"); + + nccl_ctx + .def(pybind11::init()) + .def("init", [](imperative::NCCLParallelContext& self) { self.Init(); }); +#endif } } // namespace pybind diff --git a/paddle/fluid/pybind/imperative.h b/paddle/fluid/pybind/imperative.h index 8496cbfcb18798ee8ce1714431b7877bb2b7d377..f9d4a7c990e23b30eb7f5086fe56587f7c38bd22 100644 --- a/paddle/fluid/pybind/imperative.h +++ b/paddle/fluid/pybind/imperative.h @@ -17,6 +17,7 @@ limitations under the License. */ #include #include #include "paddle/fluid/imperative/layer.h" +#include "paddle/fluid/imperative/nccl_context.h" #include "pybind11/pybind11.h" #include "pybind11/stl.h" @@ -46,7 +47,7 @@ class PyVarBase : public imperative::VarBase { using imperative::VarBase::VarBase; // Inherit constructors }; -void BindTracer(pybind11::module* m); +void BindImperative(pybind11::module* m); } // namespace pybind } // namespace paddle diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index 044677fb756e0368c65b84f15fdf2540abbd14b8..8c34e3efe2a07cadde5aa06669fda88be7661db1 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -288,7 +288,7 @@ PYBIND11_MODULE(core, m) { }) .def_static("num_funcs", &imperative::PyLayer::NumFuncs); - BindTracer(&m); + BindImperative(&m); py::class_(m, "Tensor", py::buffer_protocol()) .def_buffer( diff --git a/python/paddle/distributed/launch.py b/python/paddle/distributed/launch.py index 03c4078775d455fdb19aaf78ace4dcb98c8dd66a..d8153fa00267b00eedc52aa043af9ba7dc090f7d 100644 --- a/python/paddle/distributed/launch.py +++ b/python/paddle/distributed/launch.py @@ -32,6 +32,7 @@ default_envs = { "NCCL_SOCKET_IFNAME": "eth0", "NCCL_IB_GID_INDEX": "3", "NCCL_IB_RETRY_CNT": "0", + "PYTHONPATH": os.getenv("PYTHONPATH", ""), } GPUS = 8 diff --git a/python/paddle/fluid/dygraph/__init__.py b/python/paddle/fluid/dygraph/__init__.py index 2d0c7b7ddaacee28da599d5850e9b3381c01de5c..9bb72ede304dbde732153bac980f24a74bcd126d 100644 --- a/python/paddle/fluid/dygraph/__init__.py +++ b/python/paddle/fluid/dygraph/__init__.py @@ -29,6 +29,9 @@ from .tracer import * from . import profiler from .profiler import * +from . import parallel +from .parallel import * + from . import checkpoint from .checkpoint import * @@ -41,5 +44,6 @@ __all__ += base.__all__ __all__ += nn.__all__ __all__ += tracer.__all__ __all__ += profiler.__all__ +__all__ += parallel.__all__ __all__ += checkpoint.__all__ __all__ += learning_rate_scheduler.__all__ diff --git a/python/paddle/fluid/dygraph/nn.py b/python/paddle/fluid/dygraph/nn.py index 527c37cb2c4f1540fb8c464dfdbe061b2899f678..71abb9e3eca974138fe2d8bedd41e4d58983f80c 100644 --- a/python/paddle/fluid/dygraph/nn.py +++ b/python/paddle/fluid/dygraph/nn.py @@ -48,7 +48,7 @@ class Conv2D(layers.Layer): bias_attr=None, dtype=core.VarDesc.VarType.FP32): assert param_attr is not False, "param_attr should not be False here." - super(Conv2D, self).__init__(name_scope) + super(Conv2D, self).__init__(name_scope, dtype) self._groups = groups self._stride = utils.convert_to_list(stride, 2, 'stride') self._padding = utils.convert_to_list(padding, 2, 'padding') @@ -503,7 +503,7 @@ class FC(layers.Layer): num_flatten_dims=1, dtype=core.VarDesc.VarType.FP32, act=None): - super(FC, self).__init__(name_scope) + super(FC, self).__init__(name_scope, dtype) self._size = size self._num_flatten_dims = num_flatten_dims @@ -608,7 +608,7 @@ class BatchNorm(layers.Layer): do_model_average_for_mean_and_var=False, fuse_with_relu=False, use_global_stats=False): - super(BatchNorm, self).__init__(name_scope) + super(BatchNorm, self).__init__(name_scope, dtype) self._param_attr = param_attr self._param_attr = bias_attr self._act = act @@ -760,7 +760,7 @@ class Embedding(layers.Layer): param_attr=None, dtype='float32'): - super(Embedding, self).__init__(name_scope) + super(Embedding, self).__init__(name_scope, dtype) self._size = size self._is_sparse = is_sparse self._is_distributed = is_distributed @@ -1008,7 +1008,7 @@ class GRUUnit(layers.Layer): gate_activation='sigmoid', origin_mode=False, dtype='float32'): - super(GRUUnit, self).__init__(name_scope) + super(GRUUnit, self).__init__(name_scope, dtype) activation_dict = dict( identity=0, diff --git a/python/paddle/fluid/dygraph/parallel.py b/python/paddle/fluid/dygraph/parallel.py new file mode 100644 index 0000000000000000000000000000000000000000..f7decac963f47ba1dcc33e9c8eab7900e745d1df --- /dev/null +++ b/python/paddle/fluid/dygraph/parallel.py @@ -0,0 +1,60 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except jin compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +from .. import core + +__all__ = ["prepare_context"] + +ParallelStrategy = core.ParallelStrategy + +__parallel_ctx__clz__ = None + + +def prepare_context(parallel_strategy, place): + global __parallel_ctx__clz__ + assert __parallel_ctx__clz__ is None, "ParallelContext can only be initialized once." + + if isinstance(place, core.CUDAPlace): + __parallel_ctx__clz__ = core.NCCLParallelContext(parallel_strategy, + place) + else: + # TODO(Yancey1989): add Gloo Parallel Context to support CPU parallel computation + assert ("Only support CUDAPlace for now.") + __parallel_ctx__clz__.init() + + +class Env(object): + def __init__(self): + self._nranks = int(os.getenv("PADDLE_TRAINERS_NUM", "1")) + self._local_rank = int(os.getenv("PADDLE_TRAINER_ID", "0")) + self._dev_id = int(os.getenv("FLAGS_selected_gpus", "0")) + self._trainer_endpoints = os.getenv("PADDLE_TRAINER_ENDPOINTS", + "").split(",") + self._current_endpoint = os.getenv("PADDLE_CURRENT_ENDPOINT", "") + + @property + def nranks(self): + return self._nranks + + @property + def local_rank(self): + return self._local_rank + + @property + def dev_id(self): + return self._dev_id + + @property + def current_endpoint(self): + return self._current_endpoint diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 91414fdeb207781afd5e28afa5a3fa6e1018efb1..a5d4d3947ab9b9699f6fc8ac0d7f088ede345290 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -191,6 +191,7 @@ __all__ = [ 'kldiv_loss', 'tree_conv', 'npair_loss', + 'pixel_shuffle', 'fsp_matrix', ] @@ -480,6 +481,8 @@ def dynamic_lstm(input, forward, _ = fluid.layers.dynamic_lstm( input=forward_proj, size=hidden_dim * 4, use_peepholes=False) """ + assert _in_dygraph_mode( + ) is not True, "please use lstm instead of dynamic_lstm in dygraph mode!" assert bias_attr is not False, "bias_attr should not be False in dynamic_lstmp." helper = LayerHelper('lstm', **locals()) size = size // 4 @@ -864,6 +867,9 @@ def dynamic_lstmp(input, proj_activation="tanh") """ + assert _in_dygraph_mode( + ) is not True, "please use lstm instead of dynamic_lstmp in dygraph mode!" + assert bias_attr is not False, "bias_attr should not be False in dynamic_lstmp." helper = LayerHelper('lstmp', **locals()) size = size // 4 @@ -1035,6 +1041,9 @@ def dynamic_gru(input, hidden = fluid.layers.dynamic_gru(input=x, size=hidden_dim) """ + assert _in_dygraph_mode( + ) is not True, "please use gru instead of dynamic_gru in dygraph mode!" + helper = LayerHelper('gru', **locals()) dtype = helper.input_dtype() @@ -1751,6 +1760,8 @@ def sequence_conv(input, Variable: output of sequence_conv """ + assert not _in_dygraph_mode(), ( + "sequence layer is not supported in dygraph mode yet.") helper = LayerHelper('sequence_conv', **locals()) dtype = helper.input_dtype() filter_shape = [filter_size * input.shape[1], num_filters] @@ -1810,6 +1821,8 @@ def sequence_softmax(input, use_cudnn=False, name=None): dtype='float32', lod_level=1) x_sequence_softmax = fluid.layers.sequence_softmax(input=x) """ + assert not _in_dygraph_mode(), ( + "sequence layer is not supported in dygraph mode yet.") helper = LayerHelper('sequence_softmax', **locals()) dtype = helper.input_dtype() softmax_out = helper.create_variable_for_type_inference(dtype) @@ -2302,6 +2315,8 @@ def sequence_pool(input, pool_type, is_test=False): last_x = fluid.layers.sequence_pool(input=x, pool_type='last') first_x = fluid.layers.sequence_pool(input=x, pool_type='first') """ + assert not _in_dygraph_mode(), ( + "sequence layer is not supported in dygraph mode yet.") helper = LayerHelper('sequence_pool', **locals()) dtype = helper.input_dtype() pool_out = helper.create_variable_for_type_inference(dtype) @@ -2341,6 +2356,8 @@ def sequence_concat(input, name=None): out = fluid.layers.sequence_concat(input=[seq1, seq2, seq3]) """ + assert not _in_dygraph_mode(), ( + "sequence layer is not supported in dygraph mode yet.") helper = LayerHelper('sequence_concat', **locals()) out = helper.create_variable_for_type_inference(dtype=helper.input_dtype()) helper.append_op( @@ -2468,6 +2485,8 @@ def sequence_slice(input, offset, length, name=None): subseqs = fluid.layers.sequence_slice(input=seqs, offset=offset, length=length) """ + assert not _in_dygraph_mode(), ( + "sequence layer is not supported in dygraph mode yet.") helper = LayerHelper("sequence_slice", **locals()) dtype = helper.input_dtype() out = helper.create_variable_for_type_inference(dtype) @@ -3927,6 +3946,8 @@ def sequence_expand(x, y, ref_level=-1, name=None): dtype='float32', lod_level=1) out = layers.sequence_expand(x=x, y=y, ref_level=0) """ + assert not _in_dygraph_mode(), ( + "sequence layer is not supported in dygraph mode yet.") helper = LayerHelper('sequence_expand', input=x, **locals()) dtype = helper.input_dtype() tmp = helper.create_variable_for_type_inference(dtype) @@ -3993,6 +4014,8 @@ def sequence_expand_as(x, y, name=None): dtype='float32', lod_level=1) out = layers.sequence_expand_as(x=x, y=y) """ + assert not _in_dygraph_mode(), ( + "sequence layer is not supported in dygraph mode yet.") helper = LayerHelper('sequence_expand_as', input=x, **locals()) dtype = helper.input_dtype() tmp = helper.create_variable_for_type_inference(dtype) @@ -4039,6 +4062,8 @@ def sequence_pad(x, pad_value, maxlen=None, name=None): out = fluid.layers.sequence_pad(x=x, pad_value=pad_value) """ + assert not _in_dygraph_mode(), ( + "sequence layer is not supported in dygraph mode yet.") helper = LayerHelper('sequence_pad', input=x, **locals()) dtype = helper.input_dtype() out = helper.create_variable_for_type_inference(dtype) @@ -4105,6 +4130,8 @@ def sequence_unpad(x, length, name=None): out = fluid.layers.sequence_unpad(x=x, length=len) """ + assert not _in_dygraph_mode(), ( + "sequence layer is not supported in dygraph mode yet.") helper = LayerHelper('sequence_unpad', input=x, **locals()) dtype = helper.input_dtype() out = helper.create_variable_for_type_inference(dtype) @@ -5278,6 +5305,8 @@ def sequence_reshape(input, new_dim): x = fluid.layers.data(shape=[5, 20], dtype='float32', lod_level=1) x_reshaped = fluid.layers.sequence_reshape(input=x, new_dim=10) """ + assert not _in_dygraph_mode(), ( + "sequence layer is not supported in dygraph mode yet.") helper = LayerHelper('sequence_reshape', **locals()) out = helper.create_variable_for_type_inference(helper.input_dtype()) helper.append_op( @@ -5812,6 +5841,8 @@ def im2sequence(input, input=layer, stride=[1, 1], filter_size=[2, 2]) """ + assert not _in_dygraph_mode(), ( + "sequence layer is not supported in dygraph mode yet.") if isinstance(filter_size, int): filter_size = [filter_size, filter_size] @@ -6228,7 +6259,7 @@ def smooth_l1(x, y, inside_weight=None, outside_weight=None, sigma=None): }, outputs={'Diff': diff, 'Out': loss}, - attrs={'sigma': sigma}) + attrs={'sigma': sigma if sigma is not None else 1.0}) return loss @@ -7589,6 +7620,8 @@ def sequence_scatter(input, index, updates, name=None): output = fluid.layers.sequence_scatter(input, index, updates) """ + assert not _in_dygraph_mode(), ( + "sequence layer is not supported in dygraph mode yet.") helper = LayerHelper('sequence_scatter', **locals()) dtype = helper.input_dtype() out = helper.create_variable_for_type_inference(dtype) @@ -8677,6 +8710,8 @@ def sequence_enumerate(input, win_size, pad_value=0, name=None): x = fluid.layers.data(shape[30, 1], dtype='int32', lod_level=1) out = fluid.layers.sequence_enumerate(input=x, win_size=3, pad_value=0) """ + assert not _in_dygraph_mode(), ( + "sequence layer is not supported in dygraph mode yet.") helper = LayerHelper('sequence_enumerate', **locals()) out = helper.create_variable_for_type_inference( helper.input_dtype(), stop_gradient=True) @@ -8716,6 +8751,8 @@ def sequence_mask(x, maxlen=None, dtype='int64', name=None): Variable: The output sequence mask. """ + assert not _in_dygraph_mode(), ( + "sequence layer is not supported in dygraph mode yet.") helper = LayerHelper('sequence_mask', **locals()) if name is None: @@ -9766,6 +9803,8 @@ def sequence_reverse(x, name=None): Returns: out(${y_type}): ${y_comment} """ + assert not _in_dygraph_mode(), ( + "sequence layer is not supported in dygraph mode yet.") helper = LayerHelper("sequence_reverse", **locals()) if name is None: out = helper.create_variable_for_type_inference(dtype=x.dtype) @@ -10923,6 +10962,65 @@ def npair_loss(anchor, positive, labels, l2_reg=0.002): return l2loss + celoss +def pixel_shuffle(x, upscale_factor): + """ + + **Pixel Shuffle Layer** + + This layer rearranges elements in a tensor of shape [N, C, H, W] + to a tensor of shape [N, C/r**2, H*r, W*r]. + This is useful for implementing efficient sub-pixel convolution + with a stride of 1/r. + Please refer to the paper: `Real-Time Single Image and Video Super-Resolution + Using an Efficient Sub-Pixel Convolutional Neural Network `_ . + by Shi et. al (2016) for more details. + + .. code-block:: text + + Given a 4-D tensor with the shape: + x.shape = [1, 9, 4, 4] + Given upscale_factor: + upscale_factor= 3 + output shape is: + [1, 1, 12, 12] + + Args: + + x(Variable): The input tensor variable. + upscale_factor(int): factor to increase spatial resolution + + Returns: + + Out(Variable): the pixel shuffle result is a tensor variable with the same shape and the same type as the input. + + Raises: + + ValueError: If the square of upscale_factor cannot divide the channels of input. + + Examples: + + .. code-block:: python + + input = fluid.layers.data(shape=[9,4,4]) + output = fluid.layers.pixel_shuffle(x=input, upscale_factor=3) + + """ + + helper = LayerHelper("pixel_shuffle", **locals()) + + out = helper.create_variable_for_type_inference(dtype=x.dtype) + + if not isinstance(upscale_factor, int): + raise TypeError("upscale factor must be int type") + + helper.append_op( + type="pixel_shuffle", + inputs={"X": x}, + outputs={"Out": out}, + attrs={"upscale_factor": upscale_factor}) + return out + + def fsp_matrix(x, y): """ diff --git a/python/paddle/fluid/tests/unittests/test_eager_deletion_conditional_block.py b/python/paddle/fluid/tests/unittests/test_eager_deletion_conditional_block.py new file mode 100644 index 0000000000000000000000000000000000000000..95cae1c2029c472c5a34b37a79739e2ff088feb2 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_eager_deletion_conditional_block.py @@ -0,0 +1,23 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle.fluid as fluid +import unittest + +fluid.core._set_eager_deletion_mode(0.0, 1.0, True) + +from test_conditional_block import * + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_imperative_transformer.py b/python/paddle/fluid/tests/unittests/test_imperative_transformer.py index 732f0681c4e65006628d51e083a400c0b5bd3d92..89ae3c6a39d6277f590c8f2e02f7b0ae62a1cd4a 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_transformer.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_transformer.py @@ -302,8 +302,11 @@ use_py_reader = False # if we run sync mode sync = False -# how many batches we use -batch_num = 50 +if not core.is_compiled_with_cuda(): + # how many batches we use + batch_num = 50 +else: + batch_num = 5 np.random.seed = 1 src_word_np = np.random.randint( diff --git a/python/paddle/fluid/tests/unittests/test_layers.py b/python/paddle/fluid/tests/unittests/test_layers.py index 89ba17f8b940d8f34e0df3ce6980ce7ddced607c..98b39256aad8435a1b54fe11fbb8d4677f18e99c 100644 --- a/python/paddle/fluid/tests/unittests/test_layers.py +++ b/python/paddle/fluid/tests/unittests/test_layers.py @@ -18,6 +18,8 @@ import unittest import contextlib import numpy as np import decorators +import inspect +from six.moves import filter import paddle import paddle.fluid as fluid @@ -58,8 +60,12 @@ class LayerTest(unittest.TestCase): fluid.default_main_program().random_seed = self.seed yield - def get_static_graph_result(self, feed, fetch_list, with_lod=False): - exe = fluid.Executor(self._get_place()) + def get_static_graph_result(self, + feed, + fetch_list, + with_lod=False, + force_to_use_cpu=False): + exe = fluid.Executor(self._get_place(force_to_use_cpu)) exe.run(fluid.default_startup_program()) return exe.run(fluid.default_main_program(), feed=feed, @@ -77,7 +83,6 @@ class LayerTest(unittest.TestCase): class TestLayer(LayerTest): def test_fc(self): - # pdb.set_trace() inp = np.ones([3, 32, 32], dtype='float32') with self.static_graph(): t = layers.data( @@ -870,25 +875,102 @@ class TestLayer(LayerTest): self.assertTrue(np.allclose(dy_rlt._numpy(), static_rlt)) -class TestBook(unittest.TestCase): - def test_fit_a_line(self): - program = Program() - with program_guard(program, startup_program=Program()): - x = layers.data(name='x', shape=[13], dtype='float32') +class TestBook(LayerTest): + def test_all_layers(self): + attrs = (getattr(self, name) for name in dir(self)) + methods = filter(inspect.ismethod, attrs) + for method in methods: + if not method.__name__.startswith('make_'): + continue + self._low_data_bound = 0 + self._high_data_bound = 2 + self._batch_size = 2 + self._feed_dict = {} + self._force_to_use_cpu = False + with self.static_graph(): + static_var = method() + if isinstance(static_var, tuple): + static_var = static_var[0] + + if static_var is not None: + fetch_list = [static_var.name] + static_result = self.get_static_graph_result( + feed=self._feed_dict, + fetch_list=fetch_list, + force_to_use_cpu=self._force_to_use_cpu) + else: + assert method.__name__ in ('make_get_places') + continue + + with self.dynamic_graph(self._force_to_use_cpu): + dy_result = method() + if isinstance(dy_result, tuple): + dy_result = dy_result[0] + + self.assertTrue(np.array_equal(static_result[0], dy_result._numpy())) + + def _get_np_data(self, shape, dtype, append_batch_size=True): + np.random.seed(self.seed) + if append_batch_size: + shape = [self._batch_size] + shape + if dtype == 'float32': + return np.random.random(shape).astype(dtype) + elif dtype == 'float64': + return np.random.random(shape).astype(dtype) + elif dtype == 'int32': + return np.random.randint(self._low_data_bound, + self._high_data_bound, shape).astype(dtype) + elif dtype == 'int64': + return np.random.randint(self._low_data_bound, + self._high_data_bound, shape).astype(dtype) + + def _get_data(self, + name, + shape, + dtype, + set_feed_dict=True, + append_batch_size=True): + if base.enabled(): + return base.to_variable( + value=self._get_np_data(shape, dtype, append_batch_size), + name=name) + else: + if set_feed_dict: + self._feed_dict[name] = self._get_np_data(shape, dtype, + append_batch_size) + return layers.data( + name=name, + shape=shape, + dtype=dtype, + append_batch_size=append_batch_size) + + def make_sampled_softmax_with_cross_entropy(self): + with program_guard(fluid.default_main_program(), + fluid.default_startup_program()): + logits = self._get_data(name='Logits', shape=[256], dtype='float32') + label = self._get_data(name='Label', shape=[1], dtype='int64') + num_samples = 25 + output = layers.sampled_softmax_with_cross_entropy(logits, label, + num_samples) + return (output) + + def make_fit_a_line(self): + with program_guard( + fluid.default_main_program(), + startup_program=fluid.default_startup_program()): + x = self._get_data(name='x', shape=[13], dtype='float32') y_predict = layers.fc(input=x, size=1, act=None) - y = layers.data(name='y', shape=[1], dtype='float32') + y = self._get_data(name='y', shape=[1], dtype='float32') cost = layers.square_error_cost(input=y_predict, label=y) avg_cost = layers.mean(cost) - self.assertIsNotNone(avg_cost) + return (avg_cost) - print(str(program)) - - def test_recognize_digits_mlp(self): - program = Program() - with program_guard(program, startup_program=Program()): + def make_recognize_digits_mlp(self): + with program_guard(fluid.default_main_program(), + fluid.default_startup_program()): # Change g_program, so the rest layers use `g_program` - images = layers.data(name='pixel', shape=[784], dtype='float32') - label = layers.data(name='label', shape=[1], dtype='int32') + images = self._get_data(name='pixel', shape=[784], dtype='float32') + label = self._get_data(name='label', shape=[1], dtype='int64') hidden1 = layers.fc(input=images, size=128, act='relu') hidden2 = layers.fc(input=hidden1, size=64, act='relu') predict = layers.fc(input=[hidden2, hidden1], @@ -897,32 +979,21 @@ class TestBook(unittest.TestCase): param_attr=["sftmax.w1", "sftmax.w2"]) cost = layers.cross_entropy(input=predict, label=label) avg_cost = layers.mean(cost) - self.assertIsNotNone(avg_cost) + return (avg_cost) - print(str(program)) - - def test_simple_conv2d(self): - program = Program() - with program_guard(program, startup_program=Program()): - images = layers.data( - name='pixel', shape=[3, 48, 48], dtype='float32') - layers.conv2d(input=images, num_filters=3, filter_size=[4, 4]) - - print(str(program)) - - def test_conv2d_transpose(self): - program = Program() - with program_guard(program): - img = layers.data(name='pixel', shape=[3, 2, 2], dtype='float32') - layers.conv2d_transpose(input=img, num_filters=10, output_size=28) - print(str(program)) + def make_conv2d_transpose(self): + with program_guard(fluid.default_main_program(), + fluid.default_startup_program()): + img = self._get_data(name='pixel', shape=[3, 2, 2], dtype='float32') + return layers.conv2d_transpose( + input=img, num_filters=10, output_size=28) - def test_recognize_digits_conv(self): - program = Program() - with program_guard(program, startup_program=Program()): - images = layers.data( + def make_recognize_digits_conv(self): + with program_guard(fluid.default_main_program(), + fluid.default_startup_program()): + images = self._get_data( name='pixel', shape=[1, 28, 28], dtype='float32') - label = layers.data(name='label', shape=[1], dtype='int32') + label = self._get_data(name='label', shape=[1], dtype='int64') conv_pool_1 = nets.simple_img_conv_pool( input=images, filter_size=5, @@ -941,19 +1012,19 @@ class TestBook(unittest.TestCase): predict = layers.fc(input=conv_pool_2, size=10, act="softmax") cost = layers.cross_entropy(input=predict, label=label) avg_cost = layers.mean(cost) + return avg_cost - print(str(program)) - - def test_word_embedding(self): - program = Program() - with program_guard(program, startup_program=Program()): + def make_word_embedding(self): + with program_guard(fluid.default_main_program(), + fluid.default_startup_program()): dict_size = 10000 embed_size = 32 - first_word = layers.data(name='firstw', shape=[1], dtype='int64') - second_word = layers.data(name='secondw', shape=[1], dtype='int64') - third_word = layers.data(name='thirdw', shape=[1], dtype='int64') - forth_word = layers.data(name='forthw', shape=[1], dtype='int64') - next_word = layers.data(name='nextw', shape=[1], dtype='int64') + first_word = self._get_data(name='firstw', shape=[1], dtype='int64') + second_word = self._get_data( + name='secondw', shape=[1], dtype='int64') + third_word = self._get_data(name='thirdw', shape=[1], dtype='int64') + forth_word = self._get_data(name='forthw', shape=[1], dtype='int64') + next_word = self._get_data(name='nextw', shape=[1], dtype='int64') embed_first = layers.embedding( input=first_word, @@ -987,257 +1058,126 @@ class TestBook(unittest.TestCase): act='softmax') cost = layers.cross_entropy(input=predict_word, label=next_word) avg_cost = layers.mean(cost) - self.assertIsNotNone(avg_cost) + return (avg_cost) - print(str(program)) - - def test_linear_chain_crf(self): - program = Program() - with program_guard(program, startup_program=Program()): - label_dict_len = 10 - images = layers.data(name='pixel', shape=[784], dtype='float32') - label = layers.data(name='label', shape=[1], dtype='int32') - hidden = layers.fc(input=images, size=128) - crf = layers.linear_chain_crf( - input=hidden, label=label, param_attr=ParamAttr(name="crfw")) - crf_decode = layers.crf_decoding( - input=hidden, param_attr=ParamAttr(name="crfw")) - layers.chunk_eval( - input=crf_decode, - label=label, - chunk_scheme="IOB", - num_chunk_types=(label_dict_len - 1) // 2) - self.assertFalse(crf is None) - self.assertFalse(crf_decode is None) - - print(str(program)) - - def test_sigmoid_cross_entropy(self): - program = Program() - with program_guard(program): - dat = layers.data(name='data', shape=[10], dtype='float32') - lbl = layers.data(name='label', shape=[10], dtype='float32') + def make_sigmoid_cross_entropy(self): + with program_guard(fluid.default_main_program(), + fluid.default_startup_program()): + dat = self._get_data(name='data', shape=[10], dtype='float32') + lbl = self._get_data(name='label', shape=[10], dtype='float32') ignore_index = -1 - self.assertIsNotNone( - layers.sigmoid_cross_entropy_with_logits( - x=dat, label=lbl, ignore_index=ignore_index)) - print(str(program)) - - def test_hsigmoid(self): - program = Program() - with program_guard(program): - x = layers.data(name='x', shape=[2], dtype='float32') - y = layers.data(name='y', shape=[2], dtype='int64') - self.assertIsNotNone( - layers.hsigmoid( - input=x, label=y, num_classes=2)) - print(str(program)) + return (layers.sigmoid_cross_entropy_with_logits( + x=dat, label=lbl, ignore_index=ignore_index)) + + def make_hsigmoid(self): + self._force_to_use_cpu = True + with fluid.framework._dygraph_place_guard(place=fluid.CPUPlace()): + x = self._get_data(name='x', shape=[2], dtype='float32') + y = self._get_data(name='y', shape=[2], dtype='int64') + return (layers.hsigmoid(input=x, label=y, num_classes=2)) # test hsigmod with custom tree structure program2 = Program() with program_guard(program2): - x2 = layers.data(name='x2', shape=[4, 8], dtype='float32') - y2 = layers.data(name='y2', shape=[4], dtype='int64') - path_table = layers.data( + x2 = self._get_data(name='x2', shape=[4, 8], dtype='float32') + y2 = self._get_data(name='y2', shape=[4], dtype='int64') + path_table = self._get_data( name='path_table', shape=[4, 6], dtype='int64') - path_code = layers.data( + path_code = self._get_data( name='path_code', shape=[4, 6], dtype='int64') - self.assertIsNotNone( - layers.hsigmoid( - input=x2, - label=y2, - num_classes=6, - path_table=path_table, - path_code=path_code, - is_custom=True)) - print(str(program2)) - - def test_sequence_expand(self): - program = Program() - with program_guard(program): - x = layers.data(name='x', shape=[10], dtype='float32') - y = layers.data( - name='y', shape=[10, 20], dtype='float32', lod_level=2) - self.assertIsNotNone(layers.sequence_expand(x=x, y=y, ref_level=1)) - print(str(program)) - - def test_sequence_unpad(self): - program = Program() - with program_guard(program): - x = layers.data(name='x', shape=[10, 5], dtype='float32') - length = layers.data(name='length', shape=[1], dtype='int64') - self.assertIsNotNone(layers.sequence_unpad(x=x, length=length)) - print(str(program)) - - def test_pool2d(self): - program = Program() - with program_guard(program): - x = layers.data(name='x', shape=[3, 224, 224], dtype='float32') - self.assertIsNotNone( - layers.pool2d( - x, - pool_size=[5, 3], - pool_stride=[1, 2], - pool_padding=(2, 1))) - - def test_adaptive_pool2d(self): - program = Program() - with program_guard(program): - x = layers.data(name='x', shape=[3, 224, 224], dtype='float32') - self.assertIsNotNone( - layers.adaptive_pool2d( - x, [3, 3], pool_type='avg')) + return (layers.hsigmoid( + input=x2, + label=y2, + num_classes=6, + path_table=path_table, + path_code=path_code, + is_custom=True)) + + def make_pool2d(self): + with program_guard(fluid.default_main_program(), + fluid.default_startup_program()): + x = self._get_data(name='x', shape=[3, 224, 224], dtype='float32') + return (layers.pool2d( + x, pool_size=[5, 3], pool_stride=[1, 2], pool_padding=(2, 1))) + + def make_adaptive_pool2d(self): + with program_guard(fluid.default_main_program(), + fluid.default_startup_program()): + x = self._get_data(name='x', shape=[3, 224, 224], dtype='float32') + return (layers.adaptive_pool2d(x, [3, 3], pool_type='avg')) pool, mask = layers.adaptive_pool2d(x, [3, 3], require_index=True) - self.assertIsNotNone(pool) - self.assertIsNotNone(mask) - self.assertIsNotNone(layers.adaptive_pool2d(x, 3, pool_type='avg')) + return (pool) + return (mask) + return (layers.adaptive_pool2d(x, 3, pool_type='avg')) pool, mask = layers.adaptive_pool2d(x, 3, require_index=True) - self.assertIsNotNone(pool) - self.assertIsNotNone(mask) - - def test_adaptive_pool3d(self): - program = Program() - with program_guard(program): - x = layers.data(name='x', shape=[3, 244, 224, 224], dtype='float32') - self.assertIsNotNone( - layers.adaptive_pool3d( - x, [3, 3, 3], pool_type='avg')) + return (pool) + return (mask) + + def make_adaptive_pool3d(self): + with program_guard(fluid.default_main_program(), + fluid.default_startup_program()): + x = self._get_data( + name='x', shape=[3, 244, 224, 224], dtype='float32') + return (layers.adaptive_pool3d(x, [3, 3, 3], pool_type='avg')) pool, mask = layers.adaptive_pool3d( x, [3, 3, 3], require_index=True) - self.assertIsNotNone(pool) - self.assertIsNotNone(mask) - self.assertIsNotNone(layers.adaptive_pool3d(x, 3, pool_type='avg')) + return (pool) + return (mask) + return (layers.adaptive_pool3d(x, 3, pool_type='avg')) pool, mask = layers.adaptive_pool3d(x, 3, require_index=True) - self.assertIsNotNone(pool) - self.assertIsNotNone(mask) + return (pool) + return (mask) - def test_lstm_unit(self): - program = Program() - with program_guard(program): - x_t_data = layers.data( + def make_lstm_unit(self): + with program_guard(fluid.default_main_program(), + fluid.default_startup_program()): + x_t_data = self._get_data( name='x_t_data', shape=[10, 10], dtype='float32') x_t = layers.fc(input=x_t_data, size=10) - prev_hidden_data = layers.data( + prev_hidden_data = self._get_data( name='prev_hidden_data', shape=[10, 30], dtype='float32') prev_hidden = layers.fc(input=prev_hidden_data, size=30) - prev_cell_data = layers.data( + prev_cell_data = self._get_data( name='prev_cell', shape=[10, 30], dtype='float32') prev_cell = layers.fc(input=prev_cell_data, size=30) - self.assertIsNotNone( - layers.lstm_unit( - x_t=x_t, hidden_t_prev=prev_hidden, cell_t_prev=prev_cell)) - print(str(program)) - - def test_dynamic_lstmp(self): - program = Program() - with program_guard(program): - hidden_dim, proj_dim = 16, 8 - seq_data = layers.data( - name='seq_data', shape=[10, 10], dtype='float32', lod_level=1) - fc_out = layers.fc(input=seq_data, size=4 * hidden_dim) - self.assertIsNotNone( - layers.dynamic_lstmp( - input=fc_out, size=4 * hidden_dim, proj_size=proj_dim)) - print(str(program)) + return (layers.lstm_unit( + x_t=x_t, hidden_t_prev=prev_hidden, cell_t_prev=prev_cell)) - def test_sequence_softmax(self): - program = Program() - with program_guard(program): - seq_data = layers.data( - name='seq_data', shape=[10, 10], dtype='float32', lod_level=1) - seq = layers.fc(input=seq_data, size=20) - self.assertIsNotNone(layers.sequence_softmax(seq)) - print(str(program)) - - def test_softmax(self): - program = Program() - with program_guard(program): - data = layers.data(name='data', shape=[10], dtype='float32') + def make_softmax(self): + with program_guard(fluid.default_main_program(), + fluid.default_startup_program()): + data = self._get_data(name='data', shape=[10], dtype='float32') hid = layers.fc(input=data, size=20) - self.assertIsNotNone(layers.softmax(hid, axis=1)) - print(str(program)) + return (layers.softmax(hid, axis=1)) - def test_space_to_depth(self): - program = Program() - with program_guard(program): - data = layers.data( + def make_space_to_depth(self): + with program_guard(fluid.default_main_program(), + fluid.default_startup_program()): + data = self._get_data( name='data', shape=[32, 9, 6, 6], append_batch_size=False, dtype='float32') - self.assertIsNotNone(layers.space_to_depth(data, 3)) - print(str(program)) + return (layers.space_to_depth(data, 3)) - def test_sequence_unsqueeze(self): - program = Program() - with program_guard(program): - x = layers.data(name='x', shape=[8, 2], dtype='float32') - out = layers.unsqueeze(input=x, axes=[1]) - self.assertIsNotNone(out) - print(str(program)) + def make_lrn(self): + with program_guard(fluid.default_main_program(), + fluid.default_startup_program()): + data = self._get_data(name='data', shape=[6, 2, 2], dtype='float32') + return (layers.lrn(data)) - def test_squeeze(self): - program = Program() - with program_guard(program): - x = layers.data(name='x', shape=[1, 1, 4], dtype='float32') - out = layers.squeeze(input=x, axes=[2]) - self.assertIsNotNone(out) - print(str(program)) - - def test_lrn(self): - program = Program() - with program_guard(program): - data = layers.data(name='data', shape=[6, 2, 2], dtype='float32') - self.assertIsNotNone(layers.lrn(data)) - print(str(program)) - - def test_get_places(self): - program = Program() - with program_guard(program): - x = get_places(device_count=4) - self.assertIsNotNone(x) - print(str(program)) - - def test_sequence_reshape(self): - program = Program() - with program_guard(program): - x = layers.data(name='x', shape=[8], dtype='float32', lod_level=1) - out = layers.sequence_reshape(input=x, new_dim=16) - self.assertIsNotNone(out) - print(str(program)) - - def test_im2sequence(self): - program = Program() - with program_guard(program): - x = layers.data(name='x', shape=[3, 128, 128], dtype='float32') - y = layers.data(name='y', shape=[], dtype='float32') - output = layers.im2sequence( - input=x, - input_image_size=y, - stride=[1, 1], - filter_size=[2, 2], - out_stride=[1, 1]) - self.assertIsNotNone(output) - print(str(program)) - - def test_sampled_softmax_with_cross_entropy(self): - program = Program() - with program_guard(program): - logits = layers.data(name='Logits', shape=[256], dtype='float64') - label = layers.data(name='Label', shape=[1], dtype='int64') - num_samples = 25 - output = layers.sampled_softmax_with_cross_entropy(logits, label, - num_samples) - self.assertIsNotNone(output) - print(str(program)) + def make_get_places(self): + with program_guard(fluid.default_main_program(), + fluid.default_startup_program()): + get_places(device_count=1) @decorators.prog_scope() - def test_nce(self): + def make_nce(self): window_size = 5 words = [] for i in range(window_size): words.append( - layers.data( + self._get_data( name='word_{0}'.format(i), shape=[1], dtype='int64')) dict_size = 10000 @@ -1263,278 +1203,168 @@ class TestBook(unittest.TestCase): param_attr='nce.w', bias_attr='nce.b') avg_loss = layers.mean(loss) - self.assertIsNotNone(avg_loss) - print(str(default_main_program())) - - def test_row_conv(self): - program = Program() - with program_guard(program): - x = layers.data(name='x', shape=[16], dtype='float32', lod_level=1) - out = layers.row_conv(input=x, future_context_size=2) - self.assertIsNotNone(out) - print(str(program)) - - def test_multiplex(self): - program = Program() - with program_guard(program): - x1 = layers.data(name='x1', shape=[4], dtype='float32') - x2 = layers.data(name='x2', shape=[4], dtype='float32') - index = layers.data(name='index', shape=[1], dtype='int32') + return (avg_loss) + + def make_multiplex(self): + with program_guard(fluid.default_main_program(), + fluid.default_startup_program()): + x1 = self._get_data(name='x1', shape=[4], dtype='float32') + x2 = self._get_data(name='x2', shape=[4], dtype='float32') + index = self._get_data(name='index', shape=[1], dtype='int32') out = layers.multiplex(inputs=[x1, x2], index=index) - self.assertIsNotNone(out) - print(str(program)) - - def test_softmax_with_cross_entropy(self): - program = Program() - with program_guard(program): - x = layers.data(name='x', shape=[16], dtype='float32') - y = layers.data(name='label', shape=[1], dtype='int64') + return (out) + + def make_softmax_with_cross_entropy(self): + with program_guard(fluid.default_main_program(), + fluid.default_startup_program()): + x = self._get_data(name='x', shape=[16], dtype='float32') + y = self._get_data(name='label', shape=[1], dtype='int64') loss, softmax = layers.softmax_with_cross_entropy( x, y, return_softmax=True) - self.assertIsNotNone(loss) - self.assertIsNotNone(softmax) + return (loss) + return (softmax) loss = layers.softmax_with_cross_entropy(x, y) - self.assertIsNotNone(loss) - print(str(program)) - - def test_smooth_l1(self): - program = Program() - with program_guard(program): - x = layers.data(name='x', shape=[4], dtype='float32') - y = layers.data(name='label', shape=[4], dtype='float32') + return (loss) + + def make_smooth_l1(self): + with program_guard(fluid.default_main_program(), + fluid.default_startup_program()): + x = self._get_data(name='x', shape=[4], dtype='float32') + y = self._get_data(name='label', shape=[4], dtype='float32') loss = layers.smooth_l1(x, y) - self.assertIsNotNone(loss) - print(str(program)) + return (loss) - def test_scatter(self): - program = Program() - with program_guard(program): - x = layers.data( + def make_scatter(self): + with program_guard(fluid.default_main_program(), + fluid.default_startup_program()): + x = self._get_data( name='x', shape=[3, 3], append_batch_size=False, dtype='float32') - idx = layers.data( + idx = self._get_data( name='idx', shape=[2], append_batch_size=False, dtype='int32') - updates = layers.data( + updates = self._get_data( name='updates', shape=[2, 3], append_batch_size=False, dtype='float32') out = layers.scatter(input=x, index=idx, updates=updates) - self.assertIsNotNone(out) - print(str(program)) - - def test_sequence_scatter(self): - program = Program() - with program_guard(program): - x = layers.data( - name='x', - shape=[3, 6], - append_batch_size=False, - dtype='float32') - idx = layers.data( - name='idx', - shape=[12, 1], - append_batch_size=False, - dtype='int32', - lod_level=1) - updates = layers.data( - name='updates', - shape=[12, 1], - append_batch_size=False, - dtype='float32', - lod_level=1) - out = layers.sequence_scatter(input=x, index=idx, updates=updates) - self.assertIsNotNone(out) - print(str(program)) - - def test_sequence_slice(self): - program = Program() - with program_guard(program): - import numpy as np - seqs = layers.data( - name='x', shape=[10, 5], dtype='float32', lod_level=1) - offset = layers.assign(input=np.array([[0, 1]]).astype('int32')) - length = layers.assign(input=np.array([[2, 1]]).astype('int32')) - out = layers.sequence_slice( - input=seqs, offset=offset, length=length) - self.assertIsNotNone(out) - print(str(program)) - - def test_lod_reset(self): - program = Program() - with program_guard(program): - x = layers.data(name='x', shape=[10], dtype='float32') - y = layers.data( - name='y', shape=[10, 20], dtype='float32', lod_level=2) - print(layers.lod_reset(x=x, y=y)) - print(str(program)) + return (out) - def test_label_smooth(self): - program = Program() - with program_guard(program): - label = layers.data(name="label", shape=[1], dtype="float32") + def make_label_smooth(self): + # TODO(minqiyang): support gpu ut + self._force_to_use_cpu = True + with fluid.framework._dygraph_place_guard(place=fluid.CPUPlace()): + label = self._get_data(name="label", shape=[1], dtype="int32") one_hot_label = layers.one_hot(input=label, depth=10) smooth_label = layers.label_smooth( - label=one_hot_label, epsilon=0.1, dtype="float32") - self.assertIsNotNone(smooth_label) - print(str(program)) - - def test_topk(self): - program = Program() - with program_guard(program): - data = layers.data(name="label", shape=[200], dtype="float32") - values, indices = layers.topk(data, k=5) - self.assertIsNotNone(values) - self.assertIsNotNone(indices) - print(str(program)) - - def test_roi_pool(self): - program = Program() - with program_guard(program): - x = layers.data(name="x", shape=[256, 30, 30], dtype="float32") - rois = layers.data( - name="rois", shape=[4], dtype="float32", lod_level=1) - output = layers.roi_pool(x, rois, 7, 7, 0.6) - self.assertIsNotNone(output) - print(str(program)) + label=one_hot_label, epsilon=0.1, dtype="int32") + return (smooth_label) - def test_psroi_pool(self): - program = Program() - with program_guard(program): - x = layers.data(name="x", shape=[245, 30, 30], dtype="float32") - rois = layers.data( - name="rois", shape=[4], dtype="float32", lod_level=1) - output = layers.psroi_pool(x, rois, 5, 0.25, 7, 7) - self.assertIsNotNone(output) - print(str(program)) - - def test_roi_align(self): - program = Program() - with program_guard(program): - x = layers.data(name="x", shape=[256, 30, 30], dtype="float32") - rois = layers.data( - name="rois", shape=[4], dtype="float32", lod_level=1) - output = layers.roi_align(x, rois, 14, 14, 0.5, 2) - self.assertIsNotNone(output) - print(str(program)) + def make_topk(self): + with program_guard(fluid.default_main_program(), + fluid.default_startup_program()): + data = self._get_data(name="label", shape=[200], dtype="float32") + values, indices = layers.topk(data, k=5) + return (values) + return (indices) - def test_resize_bilinear(self): - program = Program() - with program_guard(program): - x = layers.data(name='x', shape=[3, 9, 6], dtype="float32") + def make_resize_bilinear(self): + with program_guard(fluid.default_main_program(), + fluid.default_startup_program()): + x = self._get_data(name='x', shape=[3, 9, 6], dtype="float32") output = layers.resize_bilinear(x, out_shape=[12, 12]) - self.assertIsNotNone(output) + return (output) output = layers.resize_bilinear(x, scale=3) - self.assertIsNotNone(output) - print(str(program)) + return (output) - def test_resize_nearest(self): - program = Program() - with program_guard(program): - x = layers.data(name='x', shape=[3, 9, 6], dtype="float32") + def make_resize_nearest(self): + with program_guard(fluid.default_main_program(), + fluid.default_startup_program()): + x = self._get_data(name='x', shape=[3, 9, 6], dtype="float32") output = layers.resize_nearest(x, out_shape=[12, 12]) - self.assertIsNotNone(output) + return (output) output = layers.resize_nearest(x, scale=3) - self.assertIsNotNone(output) - print(str(program)) + return (output) - def test_polygon_box_transform(self): - program = Program() - with program_guard(program): - x = layers.data(name='x', shape=[8, 4, 4], dtype="float32") + def make_polygon_box_transform(self): + with program_guard(fluid.default_main_program(), + fluid.default_startup_program()): + x = self._get_data(name='x', shape=[8, 4, 4], dtype="float32") output = layers.polygon_box_transform(input=x) - self.assertIsNotNone(output) - print(str(program)) + return (output) - def test_l2_normalize(self): - program = Program() - with program_guard(program): - x = layers.data(name='x', shape=[8, 7, 10], dtype="float32") + def make_l2_normalize(self): + with program_guard(fluid.default_main_program(), + fluid.default_startup_program()): + x = self._get_data(name='x', shape=[8, 7, 10], dtype="float32") output = layers.l2_normalize(x, axis=1) + return output - def test_maxout(self): - program = Program() - with program_guard(program): - data = layers.data(name='x', shape=[8, 6, 6], dtype="float32") + def make_maxout(self): + with program_guard(fluid.default_main_program(), + fluid.default_startup_program()): + data = self._get_data(name='x', shape=[8, 6, 6], dtype="float32") output = layers.maxout(x=data, groups=2) - self.assertIsNotNone(output) - print(str(program)) - - def test_crop(self): - program = Program() - with program_guard(program): - x = layers.data(name='x', shape=[3, 5], dtype="float32") - y = layers.data(name='y', shape=[2, 3], dtype="float32") + return (output) + + def make_crop(self): + with program_guard(fluid.default_main_program(), + fluid.default_startup_program()): + x = self._get_data(name='x', shape=[3, 5], dtype="float32") + y = self._get_data(name='y', shape=[2, 3], dtype="float32") output = layers.crop(x, shape=y) - self.assertIsNotNone(output) - print(str(program)) - - def test_mean_iou(self): - program = Program() - with program_guard(program): - x = layers.data(name='x', shape=[16], dtype='float32') - y = layers.data(name='label', shape=[1], dtype='int64') - iou = layers.mean_iou(x, y, 2) - self.assertIsNotNone(iou) - print(str(program)) - - def test_argsort(self): - program = Program() - with program_guard(program): - data = layers.data(name='x', shape=[2, 3, 3], dtype="float32") + return (output) + + def make_mean_iou(self): + with fluid.framework._dygraph_place_guard(place=fluid.CPUPlace()): + x = self._get_data(name='x', shape=[16], dtype='int32') + y = self._get_data(name='label', shape=[16], dtype='int32') + iou = layers.mean_iou(x, y, self._high_data_bound) + return (iou) + + def make_argsort(self): + with program_guard(fluid.default_main_program(), + fluid.default_startup_program()): + data = self._get_data(name='x', shape=[2, 3, 3], dtype="float32") out, ids = layers.argsort(input=data, axis=1) - self.assertIsNotNone(out) - self.assertIsNotNone(ids) - print(str(program)) - - def test_rank_loss(self): - program = Program() - with program_guard(program): - label = layers.data( + return (out) + return (ids) + + def make_rank_loss(self): + with program_guard(fluid.default_main_program(), + fluid.default_startup_program()): + label = self._get_data( name='label', append_batch_size=False, shape=[16, 1], dtype="float32") - left = layers.data( + left = self._get_data( name='left', append_batch_size=False, shape=[16, 1], dtype="float32") - right = layers.data( + right = self._get_data( name='right', append_batch_size=False, shape=[16, 1], dtype="float32") out = layers.rank_loss(label, left, right, name="rank_loss") - self.assertIsNotNone(out) - print(str(program)) - - def test_flatten(self): - program = Program() - with program_guard(program): - x = layers.data( - name='x', - append_batch_size=False, - shape=[4, 4, 3], - dtype="float32") - out = layers.flatten(x, axis=1, name="flatten") - self.assertIsNotNone(out) + return (out) - def test_shape(self): - program = Program() - with program_guard(program): - input = layers.data( + def make_shape(self): + with program_guard(fluid.default_main_program(), + fluid.default_startup_program()): + input = self._get_data( name="input", shape=[3, 100, 100], dtype="float32") out = layers.shape(input) - self.assertIsNotNone(out) - print(str(program)) + return (out) - def test_pad2d(self): - program = Program() - with program_guard(program): - input = layers.data( + def make_pad2d(self): + with program_guard(fluid.default_main_program(), + fluid.default_startup_program()): + input = self._get_data( name="input", shape=[3, 100, 100], dtype="float32") paddings = layers.fill_constant(shape=[4], dtype='int32', value=1) out = layers.pad2d( @@ -1549,14 +1379,13 @@ class TestBook(unittest.TestCase): mode='reflect', data_format='NCHW', name="shape") - self.assertIsNotNone(out) - self.assertIsNotNone(out_1) - print(str(program)) + return (out) + return (out_1) - def test_prelu(self): - program = Program() - with program_guard(program): - input = layers.data( + def make_prelu(self): + with program_guard(fluid.default_main_program(), + fluid.default_startup_program()): + input = self._get_data( name="input", shape=[5, 200, 100, 100], dtype="float32") mode = 'channel' out = layers.prelu( @@ -1564,291 +1393,379 @@ class TestBook(unittest.TestCase): mode, param_attr=ParamAttr(initializer=Constant(1.0)), name='prelu') - self.assertIsNotNone(out) - print(str(program)) + return (out) - def test_brelu(self): - program = Program() - with program_guard(program): - input = layers.data(name="input", shape=[16], dtype="float32") + def make_brelu(self): + with program_guard(fluid.default_main_program(), + fluid.default_startup_program()): + input = self._get_data(name="input", shape=[16], dtype="float32") out = layers.brelu(input, t_min=1.0, t_max=20.0, name='brelu') - self.assertIsNotNone(out) - print(str(program)) + return (out) - def test_leaky_relu(self): - program = Program() - with program_guard(program): - input = layers.data(name="input", shape=[16], dtype="float32") + def make_leaky_relu(self): + with program_guard(fluid.default_main_program(), + fluid.default_startup_program()): + input = self._get_data(name="input", shape=[16], dtype="float32") out = layers.leaky_relu(input, alpha=0.1, name='leaky_relu') - self.assertIsNotNone(out) - print(str(program)) + return (out) - def test_soft_relu(self): - program = Program() - with program_guard(program): - input = layers.data(name="input", shape=[16], dtype="float32") + def make_soft_relu(self): + with program_guard(fluid.default_main_program(), + fluid.default_startup_program()): + input = self._get_data(name="input", shape=[16], dtype="float32") out = layers.soft_relu(input, threshold=30.0, name='soft_relu') - self.assertIsNotNone(out) - print(str(program)) + return (out) - def test_sigmoid(self): - program = Program() - with program_guard(program): - input = layers.data(name="input", shape=[16], dtype="float32") + def make_sigmoid(self): + with program_guard(fluid.default_main_program(), + fluid.default_startup_program()): + input = self._get_data(name="input", shape=[16], dtype="float32") out = layers.sigmoid(input, name='sigmoid') - self.assertIsNotNone(out) - print(str(program)) + return (out) - def test_logsigmoid(self): - program = Program() - with program_guard(program): - input = layers.data(name="input", shape=[16], dtype="float32") + def make_logsigmoid(self): + with program_guard(fluid.default_main_program(), + fluid.default_startup_program()): + input = self._get_data(name="input", shape=[16], dtype="float32") out = layers.logsigmoid(input, name='logsigmoid') - self.assertIsNotNone(out) - print(str(program)) + return (out) - def test_exp(self): - program = Program() - with program_guard(program): - input = layers.data(name="input", shape=[16], dtype="float32") + def make_exp(self): + with program_guard(fluid.default_main_program(), + fluid.default_startup_program()): + input = self._get_data(name="input", shape=[16], dtype="float32") out = layers.exp(input, name='exp') - self.assertIsNotNone(out) - print(str(program)) + return (out) - def test_tanh(self): - program = Program() - with program_guard(program): - input = layers.data(name="input", shape=[16], dtype="float32") + def make_tanh(self): + with program_guard(fluid.default_main_program(), + fluid.default_startup_program()): + input = self._get_data(name="input", shape=[16], dtype="float32") out = layers.tanh(input, name='tanh') - self.assertIsNotNone(out) - print(str(program)) + return (out) - def test_tanh_shrink(self): - program = Program() - with program_guard(program): - input = layers.data(name="input", shape=[16], dtype="float32") + def make_tanh_shrink(self): + with program_guard(fluid.default_main_program(), + fluid.default_startup_program()): + input = self._get_data(name="input", shape=[16], dtype="float32") out = layers.tanh_shrink(input, name='tanh_shrink') - self.assertIsNotNone(out) - print(str(program)) + return (out) - def test_sqrt(self): - program = Program() - with program_guard(program): - input = layers.data(name="input", shape=[16], dtype="float32") + def make_sqrt(self): + with program_guard(fluid.default_main_program(), + fluid.default_startup_program()): + input = self._get_data(name="input", shape=[16], dtype="float32") out = layers.sqrt(input, name='sqrt') - self.assertIsNotNone(out) - print(str(program)) + return (out) - def test_abs(self): - program = Program() - with program_guard(program): - input = layers.data(name="input", shape=[16], dtype="float32") + def make_abs(self): + with program_guard(fluid.default_main_program(), + fluid.default_startup_program()): + input = self._get_data(name="input", shape=[16], dtype="float32") out = layers.abs(input, name='abs') - self.assertIsNotNone(out) - print(str(program)) + return (out) - def test_ceil(self): - program = Program() - with program_guard(program): - input = layers.data(name="input", shape=[16], dtype="float32") + def make_ceil(self): + with program_guard(fluid.default_main_program(), + fluid.default_startup_program()): + input = self._get_data(name="input", shape=[16], dtype="float32") out = layers.ceil(input, name='ceil') - self.assertIsNotNone(out) - print(str(program)) + return (out) - def test_floor(self): - program = Program() - with program_guard(program): - input = layers.data(name="input", shape=[16], dtype="float32") + def make_floor(self): + with program_guard(fluid.default_main_program(), + fluid.default_startup_program()): + input = self._get_data(name="input", shape=[16], dtype="float32") out = layers.floor(input, name='floor') - self.assertIsNotNone(out) - print(str(program)) + return (out) - def test_cos(self): - program = Program() - with program_guard(program): - input = layers.data(name="input", shape=[16], dtype="float32") + def make_cos(self): + with program_guard(fluid.default_main_program(), + fluid.default_startup_program()): + input = self._get_data(name="input", shape=[16], dtype="float32") out = layers.cos(input, name='cos') - self.assertIsNotNone(out) - print(str(program)) + return (out) - def test_sin(self): - program = Program() - with program_guard(program): - input = layers.data(name="input", shape=[16], dtype="float32") + def make_sin(self): + with program_guard(fluid.default_main_program(), + fluid.default_startup_program()): + input = self._get_data(name="input", shape=[16], dtype="float32") out = layers.sin(input, name='sin') - self.assertIsNotNone(out) - print(str(program)) + return (out) - def test_round(self): - program = Program() - with program_guard(program): - input = layers.data(name="input", shape=[16], dtype="float32") + def make_round(self): + with program_guard(fluid.default_main_program(), + fluid.default_startup_program()): + input = self._get_data(name="input", shape=[16], dtype="float32") out = layers.round(input, name='round') - self.assertIsNotNone(out) - print(str(program)) + return (out) - def test_reciprocal(self): - program = Program() - with program_guard(program): - input = layers.data(name="input", shape=[16], dtype="float32") + def make_reciprocal(self): + with program_guard(fluid.default_main_program(), + fluid.default_startup_program()): + input = self._get_data(name="input", shape=[16], dtype="float32") out = layers.reciprocal(input, name='reciprocal') - self.assertIsNotNone(out) - print(str(program)) + return (out) - def test_square(self): - program = Program() - with program_guard(program): - input = layers.data(name="input", shape=[16], dtype="float32") + def make_square(self): + with program_guard(fluid.default_main_program(), + fluid.default_startup_program()): + input = self._get_data(name="input", shape=[16], dtype="float32") out = layers.square(input, name='square') - self.assertIsNotNone(out) - print(str(program)) + return (out) - def test_softplus(self): - program = Program() - with program_guard(program): - input = layers.data(name="input", shape=[16], dtype="float32") + def make_softplus(self): + with program_guard(fluid.default_main_program(), + fluid.default_startup_program()): + input = self._get_data(name="input", shape=[16], dtype="float32") out = layers.softplus(input, name='softplus') - self.assertIsNotNone(out) - print(str(program)) + return (out) - def test_softsign(self): - program = Program() - with program_guard(program): - input = layers.data(name="input", shape=[16], dtype="float32") + def make_softsign(self): + with program_guard(fluid.default_main_program(), + fluid.default_startup_program()): + input = self._get_data(name="input", shape=[16], dtype="float32") out = layers.softsign(input, name='softsign') - self.assertIsNotNone(out) - print(str(program)) + return (out) - def test_roi_perspective_transform(self): - program = Program() - with program_guard(program): - x = layers.data(name="x", shape=[256, 30, 30], dtype="float32") - rois = layers.data( - name="rois", shape=[8], dtype="float32", lod_level=1) - output = layers.roi_perspective_transform(x, rois, 7, 7, 0.6) - self.assertIsNotNone(output) - print(str(program)) - - def test_sequence_enumerate(self): - program = Program() - with program_guard(program): - x = layers.data(name="input", shape=[1], dtype='int32', lod_level=1) - out = layers.sequence_enumerate(input=x, win_size=2, pad_value=0) - print(str(program)) - - def test_cross_entropy(self): - program = Program() - with program_guard(program): - x = layers.data(name="x", shape=[30, 10], dtype="float32") - label = layers.data(name="label", shape=[30, 1], dtype="int32") + def make_cross_entropy(self): + with program_guard(fluid.default_main_program(), + fluid.default_startup_program()): + x = self._get_data(name="x", shape=[30, 10], dtype="float32") + label = self._get_data(name="label", shape=[30, 1], dtype="int64") mode = 'channel' out = layers.cross_entropy(x, label, False, 4) - self.assertIsNotNone(out) + return (out) - def test_bpr_loss(self): - program = Program() - with program_guard(program): - x = layers.data(name="x", shape=[30, 10], dtype="float32") - label = layers.data(name="label", shape=[30, 1], dtype="int32") + def make_bpr_loss(self): + self._force_to_use_cpu = True + with fluid.framework._dygraph_place_guard(place=fluid.CPUPlace()): + x = self._get_data(name="x", shape=[30, 10], dtype="float32") + label = self._get_data(name="label", shape=[30, 1], dtype="int64") out = layers.bpr_loss(x, label) - self.assertIsNotNone(out) - print(str(program)) + return (out) - def test_expand(self): - program = Program() - with program_guard(program): - x = layers.data(name="input", shape=[10], dtype='int32') + def make_expand(self): + with program_guard(fluid.default_main_program(), + fluid.default_startup_program()): + x = self._get_data(name="input", shape=[10], dtype='int32') out = layers.expand(x, [1, 2]) - print(str(program)) + return out - def test_uniform_random_batch_size_like(self): - program = Program() - with program_guard(program): - input = layers.data(name="input", shape=[13, 11], dtype='float32') + def make_uniform_random_batch_size_like(self): + with program_guard(fluid.default_main_program(), + fluid.default_startup_program()): + input = self._get_data( + name="input", shape=[13, 11], dtype='float32') out = layers.uniform_random_batch_size_like(input, [-1, 11]) - self.assertIsNotNone(out) - print(str(program)) + return (out) - def test_gaussian_random(self): - program = Program() - with program_guard(program): + def make_gaussian_random(self): + with program_guard(fluid.default_main_program(), + fluid.default_startup_program()): out = layers.gaussian_random(shape=[20, 30]) - self.assertIsNotNone(out) - print(str(program)) + return (out) - def test_sampling_id(self): - program = Program() - with program_guard(program): - x = layers.data( + def make_sampling_id(self): + with program_guard(fluid.default_main_program(), + fluid.default_startup_program()): + x = self._get_data( name="X", shape=[13, 11], dtype='float32', append_batch_size=False) out = layers.sampling_id(x) - self.assertIsNotNone(out) - print(str(program)) + return (out) - def test_gaussian_random_batch_size_like(self): - program = Program() - with program_guard(program): - input = layers.data(name="input", shape=[13, 11], dtype='float32') + def make_gaussian_random_batch_size_like(self): + with program_guard(fluid.default_main_program(), + fluid.default_startup_program()): + input = self._get_data( + name="input", shape=[13, 11], dtype='float32') out = layers.gaussian_random_batch_size_like( input, shape=[-1, 11], mean=1.0, std=2.0) - self.assertIsNotNone(out) - print(str(program)) + return (out) - def test_sum(self): - program = Program() - with program_guard(program): - input = layers.data(name="input", shape=[13, 11], dtype='float32') + def make_sum(self): + with program_guard(fluid.default_main_program(), + fluid.default_startup_program()): + input = self._get_data( + name="input", shape=[13, 11], dtype='float32') out = layers.sum(input) - self.assertIsNotNone(out) - print(str(program)) + return (out) - def test_slice(self): + def make_slice(self): starts = [1, 0, 2] ends = [3, 3, 4] axes = [0, 1, 2] - program = Program() - with program_guard(program): - input = layers.data( + with program_guard(fluid.default_main_program(), + fluid.default_startup_program()): + input = self._get_data( name="input", shape=[3, 4, 5, 6], dtype='float32') out = layers.slice(input, axes=axes, starts=starts, ends=ends) + return out - def test_softshrink(self): - program = Program() - with program_guard(program): - input = layers.data(name="input", shape=[16], dtype="float32") + def make_softshrink(self): + with program_guard(fluid.default_main_program(), + fluid.default_startup_program()): + input = self._get_data(name="input", shape=[16], dtype="float32") out = layers.softshrink(input, name='softshrink') - self.assertIsNotNone(out) - print(str(program)) - - def iou_similarity(self): - program = Program() - with program_guard(program): - x = layers.data(name="x", shape=[16], dtype="float32") - y = layers.data(name="y", shape=[16], dtype="float32") + return (out) + + def make_iou_similarity(self): + with program_guard(fluid.default_main_program(), + fluid.default_startup_program()): + x = self._get_data(name="x", shape=[4], dtype="float32") + y = self._get_data(name="y", shape=[4], dtype="float32") out = layers.iou_similarity(x, y, name='iou_similarity') - self.assertIsNotNone(out) - print(str(program)) - - def test_grid_sampler(self): - program = Program() - with program_guard(program): - x = layers.data(name='x', shape=[3, 5, 7], dtype='float32') - grid = layers.data(name='grid', shape=[5, 7, 2], dtype='float32') + return (out) + + def make_grid_sampler(self): + with program_guard(fluid.default_main_program(), + fluid.default_startup_program()): + x = self._get_data(name='x', shape=[3, 5, 7], dtype='float32') + grid = self._get_data(name='grid', shape=[5, 7, 2], dtype='float32') out = layers.grid_sampler(x, grid) - self.assertIsNotNone(out) - print(str(program)) + return (out) + + def make_bilinear_tensor_product_layer(self): + with program_guard(fluid.default_main_program(), + fluid.default_startup_program()): + data = self._get_data(name='data', shape=[4], dtype="float32") + + theta = self._get_data(name="theta", shape=[5], dtype="float32") + out = layers.bilinear_tensor_product(data, theta, 6) + return (out) + + def make_batch_norm(self): + with program_guard(fluid.default_main_program(), + fluid.default_startup_program()): + data = self._get_data( + name='data', shape=[32, 128, 128], dtype="float32") + out = layers.batch_norm(data) + return (out) + + def make_range(self): + with program_guard(fluid.default_main_program(), + fluid.default_startup_program()): + layers.range(0, 10, 2, 'int32') + y = layers.range(0.1, 10.0, 0.2, 'float32') + return y + + def make_spectral_norm(self): + with program_guard(fluid.default_main_program(), + fluid.default_startup_program()): + weight = self._get_data( + name='weight', + shape=[2, 3, 32, 32], + dtype="float32", + append_batch_size=False) + out = layers.spectral_norm(weight, dim=1, power_iters=1) + return (out) + + def make_kldiv_loss(self): + with program_guard(fluid.default_main_program(), + fluid.default_startup_program()): + x = self._get_data( + name='x', + shape=[32, 128, 128], + dtype="float32", + append_batch_size=False) + target = self._get_data( + name='target', + shape=[32, 128, 128], + dtype="float32", + append_batch_size=False) + loss = layers.kldiv_loss(x=x, target=target, reduction='batchmean') + return (loss) + + def make_temporal_shift(self): + with program_guard(fluid.default_main_program(), + fluid.default_startup_program()): + x = self._get_data(name="X", shape=[16, 4, 4], dtype="float32") + out = layers.temporal_shift(x, seg_num=2, shift_ratio=0.2) + return (out) + + def make_shuffle_channel(self): + with program_guard(fluid.default_main_program(), + fluid.default_startup_program()): + x = self._get_data(name="X", shape=[16, 4, 4], dtype="float32") + out = layers.shuffle_channel(x, group=4) + return (out) + + def make_fsp_matrix(self): + with program_guard(fluid.default_main_program(), + fluid.default_startup_program()): + x = self._get_data(name="X", shape=[16, 4, 4], dtype="float32") + y = self._get_data(name="Y", shape=[8, 4, 4], dtype="float32") + out = layers.fsp_matrix(x, y) + return (out) + + def make_pixel_shuffle(self): + with program_guard(fluid.default_main_program(), + fluid.default_startup_program()): + x = self._get_data(name="X", shape=[9, 4, 4], dtype="float32") + out = layers.pixel_shuffle(x, upscale_factor=3) + return (out) + + def test_dynamic_lstmp(self): + # TODO(minqiyang): dygraph do not support lod now + with self.static_graph(): + hidden_dim, proj_dim = 16, 8 + seq_data = layers.data( + name='seq_data', shape=[10, 10], dtype='float32', lod_level=1) + fc_out = layers.fc(input=seq_data, size=4 * hidden_dim) + self.assertIsNotNone( + layers.dynamic_lstmp( + input=fc_out, size=4 * hidden_dim, proj_size=proj_dim)) + + def test_linear_chain_crf(self): + # TODO(minqiyang): dygraph do not support lod now + with self.static_graph(): + label_dict_len = 10 + images = layers.data(name='pixel', shape=[784], dtype='float32') + label = layers.data(name='label', shape=[1], dtype='int32') + hidden = layers.fc(input=images, size=2) + crf = layers.linear_chain_crf( + input=hidden, label=label, param_attr=ParamAttr(name="crfw")) + crf_decode = layers.crf_decoding( + input=hidden, param_attr=ParamAttr(name="crfw")) + self.assertFalse(crf is None) + self.assertFalse(crf_decode is None) + return layers.chunk_eval( + input=crf_decode, + label=label, + chunk_scheme="IOB", + num_chunk_types=(label_dict_len - 1) // 2) + + def test_im2sequence(self): + # TODO(minqiyang): dygraph do not support lod now + with self.static_graph(): + x = layers.data(name='x', shape=[3, 128, 128], dtype='float32') + y = layers.data(name='y', shape=[], dtype='float32') + output = layers.im2sequence( + input=x, + input_image_size=y, + stride=[1, 1], + filter_size=[2, 2], + out_stride=[1, 1]) + return (output) + + def test_lod_reset(self): + # TODO(minqiyang): dygraph do not support lod now + with self.static_graph(): + x = layers.data(name='x', shape=[10], dtype='float32') + y = layers.data( + name='y', shape=[10, 20], dtype='float32', lod_level=2) + return (layers.lod_reset(x=x, y=y)) def test_affine_grid(self): - program = Program() - with program_guard(program): + with self.static_graph(): data = layers.data(name='data', shape=[2, 3, 3], dtype="float32") out, ids = layers.argsort(input=data, axis=1) @@ -1860,81 +1777,153 @@ class TestBook(unittest.TestCase): self.assertIsNotNone(data_0) self.assertIsNotNone(data_1) - print(str(program)) - def test_bilinear_tensor_product_layer(self): - program = Program() - with program_guard(program): - data = layers.data(name='data', shape=[4], dtype="float32") + def test_psroi_pool(self): + # TODO(minqiyang): dygraph do not support lod now + with self.static_graph(): + x = layers.data(name="x", shape=[245, 30, 30], dtype="float32") + rois = layers.data( + name="rois", shape=[4], dtype="float32", lod_level=1) + output = layers.psroi_pool(x, rois, 5, 0.25, 7, 7) + return (output) + + def test_sequence_expand(self): + # TODO(minqiyang): dygraph do not support lod now + with self.static_graph(): + x = layers.data(name='x', shape=[10], dtype='float32') + y = layers.data( + name='y', shape=[10, 20], dtype='float32', lod_level=2) + return (layers.sequence_expand(x=x, y=y, ref_level=1)) - theta = layers.data(name="theta", shape=[5], dtype="float32") - out = layers.bilinear_tensor_product(data, theta, 6) + def test_sequence_reshape(self): + # TODO(minqiyang): dygraph do not support lod now + with self.static_graph(): + x = layers.data(name='x', shape=[8], dtype='float32', lod_level=1) + out = layers.sequence_reshape(input=x, new_dim=16) + return (out) - print(str(program)) + def test_sequence_unpad(self): + # TODO(minqiyang): dygraph do not support lod now + with self.static_graph(): + x = layers.data(name='x', shape=[10, 5], dtype='float32') + length = layers.data(name='length', shape=[1], dtype='int64') + return (layers.sequence_unpad(x=x, length=length)) - def test_batch_norm(self): - program = Program() - with program_guard(program): - data = layers.data( - name='data', shape=[32, 128, 128], dtype="float32") - out = layers.batch_norm(data) + def test_sequence_softmax(self): + # TODO(minqiyang): dygraph do not support lod now + with self.static_graph(): + seq_data = layers.data( + name='seq_data', shape=[10, 10], dtype='float32', lod_level=1) + seq = layers.fc(input=seq_data, size=20) + return (layers.sequence_softmax(seq)) - print(str(program)) + def test_sequence_unsqueeze(self): + # TODO(minqiyang): dygraph do not support lod now + with self.static_graph(): + x = layers.data(name='x', shape=[8, 2], dtype='float32') + out = layers.unsqueeze(input=x, axes=[1]) + return (out) - def test_range(self): - program = Program() - with program_guard(program): - layers.range(0, 10, 2, 'int32') - layers.range(0.1, 10.0, 0.2, 'float32') + def test_sequence_scatter(self): + # TODO(minqiyang): dygraph do not support lod now + with self.static_graph(): + x = layers.data( + name='x', + shape=[3, 6], + append_batch_size=False, + dtype='float32') + idx = layers.data( + name='idx', + shape=[12, 1], + append_batch_size=False, + dtype='int32', + lod_level=1) + updates = layers.data( + name='updates', + shape=[12, 1], + append_batch_size=False, + dtype='float32', + lod_level=1) + out = layers.sequence_scatter(input=x, index=idx, updates=updates) + return (out) - print(str(program)) + def test_sequence_slice(self): + # TODO(minqiyang): dygraph do not support lod now + with self.static_graph(): + import numpy as np + seqs = layers.data( + name='x', shape=[10, 5], dtype='float32', lod_level=1) + offset = layers.assign(input=np.array([[0, 1]]).astype('int32')) + length = layers.assign(input=np.array([[2, 1]]).astype('int32')) + out = layers.sequence_slice( + input=seqs, offset=offset, length=length) + return (out) - def test_spectral_norm(self): - program = Program() - with program_guard(program): - weight = layers.data( - name='weight', - shape=[2, 3, 32, 32], - dtype="float32", - append_batch_size=False) - out = layers.spectral_norm(weight, dim=1, power_iters=1) - self.assertIsNotNone(out) - - def test_kldiv_loss(self): - program = Program() - with program_guard(program): - x = layers.data(name='x', shape=[32, 128, 128], dtype="float32") - target = layers.data( - name='target', shape=[32, 128, 128], dtype="float32") - loss = layers.kldiv_loss(x=x, target=target, reduction='batchmean') - self.assertIsNotNone(loss) - - print(str(program)) - - def test_temporal_shift(self): - program = Program() - with program_guard(program): - x = layers.data(name="X", shape=[16, 4, 4], dtype="float32") - out = layers.temporal_shift(x, seg_num=4, shift_ratio=0.2) - self.assertIsNotNone(out) - print(str(program)) - - def test_shuffle_channel(self): - program = Program() - with program_guard(program): - x = layers.data(name="X", shape=[16, 4, 4], dtype="float32") - out = layers.shuffle_channel(x, group=4) - self.assertIsNotNone(out) - print(str(program)) - - def test_fsp(self): - program = Program() - with program_guard(program): - x = layers.data(name="X", shape=[16, 4, 4], dtype="float32") - y = layers.data(name="Y", shape=[8, 4, 4], dtype="float32") - out = layers.fsp_matrix(x, y) - self.assertIsNotNone(out) - print(str(program)) + def test_roi_pool(self): + # TODO(minqiyang): dygraph do not support lod now + with self.static_graph(): + x = layers.data(name="x", shape=[256, 30, 30], dtype="float32") + rois = layers.data( + name="rois", shape=[4], dtype="float32", lod_level=1) + output = layers.roi_pool(x, rois, 7, 7, 0.6) + return (output) + + def test_sequence_enumerate(self): + # TODO(minqiyang): dygraph do not support lod now + with self.static_graph(): + x = layers.data(name="input", shape=[1], dtype='int32', lod_level=1) + out = layers.sequence_enumerate(input=x, win_size=2, pad_value=0) + + def test_roi_align(self): + # TODO(minqiyang): dygraph do not support lod now + with self.static_graph(): + x = layers.data(name="x", shape=[256, 30, 30], dtype="float32") + rois = layers.data( + name="rois", shape=[4], dtype="float32", lod_level=1) + output = layers.roi_align(x, rois, 14, 14, 0.5, 2) + return (output) + + def test_roi_perspective_transform(self): + # TODO(minqiyang): dygraph do not support lod now + with self.static_graph(): + x = layers.data(name="x", shape=[256, 30, 30], dtype="float32") + rois = layers.data( + name="rois", shape=[8], dtype="float32", lod_level=1) + output = layers.roi_perspective_transform(x, rois, 7, 7, 0.6) + return (output) + + def test_row_conv(self): + # TODO(minqiyang): dygraph do not support lod now + with self.static_graph(): + x = layers.data(name='x', shape=[16], dtype='float32', lod_level=1) + out = layers.row_conv(input=x, future_context_size=2) + return (out) + + def test_simple_conv2d(self): + # TODO(minqiyang): dygraph do not support layers with param now + with self.static_graph(): + images = layers.data( + name='pixel', shape=[3, 48, 48], dtype='float32') + return layers.conv2d( + input=images, num_filters=3, filter_size=[4, 4]) + + def test_squeeze(self): + # TODO(minqiyang): dygraph do not support layers with param now + with self.static_graph(): + x = layers.data(name='x', shape=[1, 1, 4], dtype='float32') + out = layers.squeeze(input=x, axes=[2]) + return (out) + + def test_flatten(self): + # TODO(minqiyang): dygraph do not support op without kernel now + with self.static_graph(): + x = layers.data( + name='x', + append_batch_size=False, + shape=[4, 4, 3], + dtype="float32") + out = layers.flatten(x, axis=1, name="flatten") + return (out) if __name__ == '__main__': diff --git a/python/paddle/fluid/tests/unittests/test_pixel_shuffle.py b/python/paddle/fluid/tests/unittests/test_pixel_shuffle.py new file mode 100644 index 0000000000000000000000000000000000000000..cc3ae2b3b9d4c40a7ee992c04cac79f518acac6d --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_pixel_shuffle.py @@ -0,0 +1,50 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +from op_test import OpTest + + +class TestPixelShuffle(OpTest): + def setUp(self): + self.op_type = "pixel_shuffle" + n, c, h, w = 2, 9, 4, 4 + up_factor = 3 + shape = [n, c, h, w] + x = np.random.random(shape).astype("float32") + new_shape = (n, c // (up_factor * up_factor), up_factor, up_factor, h, + w) + # reshape to (num,output_channel,upscale_factor,upscale_factor,h,w) + npresult = np.reshape(x, new_shape) + # transpose to (num,output_channel,h,upscale_factor,w,upscale_factor) + npresult = npresult.transpose(0, 1, 4, 2, 5, 3) + oshape = [n, c // (up_factor * up_factor), h * up_factor, w * up_factor] + npresult = np.reshape(npresult, oshape) + + self.inputs = {'X': x} + self.outputs = {'Out': npresult} + self.attrs = {'upscale_factor': up_factor} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(['X'], 'Out') + + +if __name__ == '__main__': + unittest.main() diff --git a/python/setup.py.in b/python/setup.py.in index 9ab4e9742cfbaf4e2d08e7c27b6ba231c85c4ec2..eef8afac65225e78f1f5bff35d74311e6450191c 100644 --- a/python/setup.py.in +++ b/python/setup.py.in @@ -157,10 +157,6 @@ package_data['paddle.libs']= [] package_data['paddle.libs']=[('libwarpctc' if os.name != 'nt' else 'warpctc') + ext_name] shutil.copy('${WARPCTC_LIBRARIES}', libs_path) -if '${WITH_WBAES}' == 'ON': - package_data['paddle.libs'] += ['libwbaes' + ext_name] - shutil.copy('${WBAES_SHARED_LIB}', libs_path) - if '${WITH_MKL}' == 'ON': shutil.copy('${MKLML_SHARED_LIB}', libs_path) shutil.copy('${MKLML_SHARED_IOMP_LIB}', libs_path)