diff --git a/AUTHORS.md b/AUTHORS.md index 4ee05420982d13f686cf13e8957ce41dfcdd2cb8..11f227be7148d8d6e055538347a8c31679406c84 100644 --- a/AUTHORS.md +++ b/AUTHORS.md @@ -4,6 +4,7 @@ | backyes | Yan-Fei Wang | | baiyfbupt | Yi-Fan Bai | | beckett1124 | Bin Qi | +| ChengduoZH | Cheng-Duo Zhao| | chengxiaohua1105 | Xiao-Hua Cheng | | cxwangyi, yiwangbaidu, wangkuiyi | Yi Wang | | cxysteven | Xing-Yi Cheng | diff --git a/doc/fluid/howto/index_cn.rst b/doc/fluid/howto/index_cn.rst index b7c620179724ebe97a0a47b75a57b376b21ccf90..b57af64f44da82926c4862578f3072960ca5aa92 100644 --- a/doc/fluid/howto/index_cn.rst +++ b/doc/fluid/howto/index_cn.rst @@ -4,5 +4,5 @@ .. toctree:: :maxdepth: 1 + inference/index_cn.rst optimization/index_cn.rst - inference/inference_support_in_fluid.md diff --git a/doc/fluid/howto/index_en.rst b/doc/fluid/howto/index_en.rst index f3ca41cdbf1d40ec8afaf045233a38755d8a777a..fd21e167ce3a46da167db1e9d7013804f730e047 100644 --- a/doc/fluid/howto/index_en.rst +++ b/doc/fluid/howto/index_en.rst @@ -5,4 +5,3 @@ HOW TO :maxdepth: 1 optimization/index_en.rst - inference/inference_support_in_fluid.md diff --git a/doc/fluid/howto/inference/build_and_install_lib_cn.rst b/doc/fluid/howto/inference/build_and_install_lib_cn.rst new file mode 100644 index 0000000000000000000000000000000000000000..c8d9992fcc92c25f8c14f71c79bde9f79fd92b1f --- /dev/null +++ b/doc/fluid/howto/inference/build_and_install_lib_cn.rst @@ -0,0 +1,96 @@ +安装与编译C++预测库 +=========================== + +直接下载安装 +------------- + +====================== ======================================== +版本说明 C++预测库 +====================== ======================================== +cpu_avx_mkl `fluid.tgz `_ +cpu_avx_openblas `fluid.tgz `_ +cpu_noavx_openblas `fluid.tgz `_ +cuda7.5_cudnn5_avx_mkl `fluid.tgz `_ +cuda8.0_cudnn5_avx_mkl `fluid.tgz `_ +cuda8.0_cudnn7_avx_mkl `fluid.tgz `_ +====================== ======================================== + +从源码编译 +---------- +用户也可以从 PaddlePaddle 核心代码编译C++预测库,只需在编译时配制下面这些编译选项: + +================= ========= +选项 值 +================= ========= +CMAKE_BUILD_TYPE Release +FLUID_INSTALL_DIR 安装路径 +WITH_FLUID_ONLY ON(推荐) +WITH_SWIG_PY OFF(推荐 +WITH_PYTHON OFF(推荐) +WITH_GPU ON/OFF +WITH_MKL ON/OFF +================= ========= + +建议按照推荐值设置,以避免链接不必要的库。其它可选编译选项按需进行设定。 + +下面的代码片段从github拉取最新代码,配制编译选项(需要将PADDLE_ROOT替换为PaddlePaddle预测库的安装路径): + + .. code-block:: bash + + pip install paddlepaddle-gpu + PADDLE_ROOT=/path/of/capi + git clone https://github.com/PaddlePaddle/Paddle.git + cd Paddle + mkdir build + cd build + cmake -DFLUID_INSTALL_DIR=$PADDLE_ROOT \ + -DCMAKE_BUILD_TYPE=Release \ + -DWITH_FLUID_ONLY=ON \ + -DWITH_SWIG_PY=OFF \ + -DWITH_PYTHON=OFF \ + -DWITH_MKL=OFF \ + -DWITH_GPU=OFF \ + .. + make + make inference_lib_dist + +成功编译后,使用C++预测库所需的依赖(包括:(1)编译出的PaddlePaddle预测库和头文件;(2)第三方链接库和头文件;(3)版本信息与编译选项信息) +均会存放于PADDLE_ROOT目录中。目录结构如下: + + .. code-block:: text + + PaddleRoot/ + ├── CMakeCache.txt + ├── paddle + │   └── fluid + │   ├── framework + │   ├── inference + │   ├── memory + │   ├── platform + │   ├── pybind + │   └── string + ├── third_party + │   ├── boost + │   │   └── boost + │   ├── eigen3 + │   │   ├── Eigen + │   │   └── unsupported + │   └── install + │   ├── gflags + │   ├── glog + │   ├── mklml + │   ├── protobuf + │   ├── snappy + │   ├── snappystream + │   └── zlib + └── version.txt + +version.txt 中记录了该预测库的版本信息,包括Git Commit ID、使用OpenBlas或MKL数学库、CUDA/CUDNN版本号,如: + + .. code-block:: text + + GIT COMMIT ID: c95cd4742f02bb009e651a00b07b21c979637dc8 + WITH_MKL: ON + WITH_GPU: ON + CUDA version: 8.0 + CUDNN version: v5 diff --git a/doc/fluid/howto/inference/index_cn.rst b/doc/fluid/howto/inference/index_cn.rst new file mode 100644 index 0000000000000000000000000000000000000000..a903423548decd0992bf19772fb2cb143f6a12b5 --- /dev/null +++ b/doc/fluid/howto/inference/index_cn.rst @@ -0,0 +1,8 @@ +预测库 +------------ + +.. toctree:: + :maxdepth: 1 + + build_and_install_lib_cn.rst + inference_support_in_fluid_cn.md diff --git a/doc/fluid/howto/inference/inference_support_in_fluid.md b/doc/fluid/howto/inference/inference_support_in_fluid_cn.md similarity index 90% rename from doc/fluid/howto/inference/inference_support_in_fluid.md rename to doc/fluid/howto/inference/inference_support_in_fluid_cn.md index d272cd3e3bdac49b9ed1a21531de1b0be03d881e..309b17fccd5c461c9c22beb64eb4c6792b7e4a7a 100644 --- a/doc/fluid/howto/inference/inference_support_in_fluid.md +++ b/doc/fluid/howto/inference/inference_support_in_fluid_cn.md @@ -1,9 +1,8 @@ -# Fluid Inference使用指南 +# 使用指南 ## 目录: - Python Inference API -- 编译Fluid Inference库 - Inference C++ API - Inference实例 - Inference计算优化 @@ -55,62 +54,6 @@ return [program, feed_target_names, fetch_targets] ``` - -## 编译Fluid Inference库 - - - **不需要额外的CMake选项** - - 1、 配置CMake命令,更多配置请参考[源码编译PaddlePaddle](http://www.paddlepaddle.org/docs/develop/documentation/zh/build_and_install/build_from_source_cn.html) - ```bash - $ git clone https://github.com/PaddlePaddle/Paddle.git - $ cd Paddle - $ mkdir build - $ cd build - $ cmake -DCMAKE_INSTALL_PREFIX=your/path/to/paddle_inference_lib \ - -DCMAKE_BUILD_TYPE=Release \ - -DWITH_PYTHON=ON \ - -DWITH_MKL=OFF \ - -DWITH_GPU=OFF \ - .. - ``` - - - 2、 编译PaddlePaddle - ```bash - $ make - ``` - - - 3、 部署。执行如下命令将PaddlePaddle Fluid Inference库部署到`your/path/to/paddle_inference_lib`目录。 - ```bash - $ make inference_lib_dist - ``` - -- 目录结构 - - ```bash - $ cd your/path/to/paddle_inference_lib - $ tree - . - |-- paddle - | `-- fluid - | |-- framework - | |-- inference - | | |-- io.h - | | `-- libpaddle_fluid.so - | |-- memory - | |-- platform - | `-- string - |-- third_party - | |-- eigen3 - | `-- install - | |-- gflags - | |-- glog - | `-- protobuf - `-- ... - ``` - - 假设`PADDLE_ROOT=your/path/to/paddle_inference_lib`。 - - - ## 链接Fluid Inference库 - 示例项目([链接](https://github.com/luotao1/fluid_inference_example.git)) diff --git a/paddle/contrib/inference/CMakeLists.txt b/paddle/contrib/inference/CMakeLists.txt index 3beb93c4e7fd1ce4dd2131cb53cb6e89e0f10ebd..6847f7db7fc0f6b41ced1260d409ca6eba9b53eb 100644 --- a/paddle/contrib/inference/CMakeLists.txt +++ b/paddle/contrib/inference/CMakeLists.txt @@ -17,32 +17,21 @@ if(APPLE) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-error=pessimizing-move") endif(APPLE) -function(inference_api_test TARGET_NAME TEST_SRC) +function(inference_api_test TARGET_NAME) set(options "") set(oneValueArgs "") set(multiValueArgs ARGS) cmake_parse_arguments(inference_test "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) set(PYTHON_TESTS_DIR ${PADDLE_BINARY_DIR}/python/paddle/fluid/tests) - set(arg_list "") + cc_test(test_paddle_inference_${TARGET_NAME} + SRCS test_paddle_inference_${TARGET_NAME}.cc + DEPS paddle_fluid_api paddle_inference_api + ARGS --dirname=${PYTHON_TESTS_DIR}/book/) if(inference_test_ARGS) - foreach(arg ${inference_test_ARGS}) - list(APPEND arg_list "_${arg}") - endforeach() - else() - list(APPEND arg_list "_") + set_tests_properties(test_paddle_inference_${TARGET_NAME} + PROPERTIES DEPENDS "${inference_test_ARGS}") endif() - foreach(arg ${arg_list}) - string(REGEX REPLACE "^_$" "" arg "${arg}") - cc_test(${TARGET_NAME} - SRCS ${TEST_SRC} - DEPS paddle_fluid_api paddle_inference_api - ARGS --dirname=${PYTHON_TESTS_DIR}/book/) - # TODO(panyx0178): Figure out how to add word2vec and image_classification - # as deps. - # set_tests_properties(${TARGET_NAME} - # PROPERTIES DEPENDS ${DEP_TEST}) - endforeach() endfunction(inference_api_test) @@ -50,9 +39,11 @@ cc_library(paddle_inference_api SRCS paddle_inference_api.cc paddle_inference_api_impl.cc DEPS ${FLUID_CORE_MODULES} ${GLOB_OP_LIB}) -cc_test(test_paddle_inference_api - SRCS test_paddle_inference_api.cc - DEPS paddle_inference_api) +if(WITH_TESTING) + cc_test(test_paddle_inference_api + SRCS test_paddle_inference_api.cc + DEPS paddle_inference_api) -inference_api_test(test_paddle_inference_api_impl - test_paddle_inference_api_impl.cc) + inference_api_test(api_impl + ARGS test_word2vec test_image_classification) +endif() diff --git a/paddle/fluid/framework/block_desc.cc b/paddle/fluid/framework/block_desc.cc index fd409ed4c0f7a504686765909e9c71692aab8824..e7842e9b8130d35e511e02dfb1dc27f307d17f38 100644 --- a/paddle/fluid/framework/block_desc.cc +++ b/paddle/fluid/framework/block_desc.cc @@ -200,7 +200,7 @@ BlockDesc::BlockDesc(ProgramDesc *prog, proto::BlockDesc *desc) vars_[var_desc.name()].reset(new VarDesc(var_desc)); } for (const proto::OpDesc &op_desc : desc_->ops()) { - ops_.emplace_back(new OpDesc(op_desc, prog, this)); + ops_.emplace_back(new OpDesc(op_desc, this)); } } @@ -209,7 +209,7 @@ BlockDesc::BlockDesc(const BlockDesc &other, proto::BlockDesc *desc, : prog_(prog), desc_(desc) { need_update_ = true; for (auto &op : other.ops_) { - ops_.emplace_back(new OpDesc(*op->Proto(), prog, this)); + ops_.emplace_back(new OpDesc(*op, this)); } for (auto &it : other.vars_) { auto *var = new VarDesc(*it.second); diff --git a/paddle/fluid/framework/block_desc.h b/paddle/fluid/framework/block_desc.h index 600601669c5d56a3ffc2fb9c804ffad5fde58f0b..189dd6c52f85b5bf623b98c64c07c0c7269505d4 100644 --- a/paddle/fluid/framework/block_desc.h +++ b/paddle/fluid/framework/block_desc.h @@ -105,7 +105,7 @@ class BlockDesc { size_t OpSize() const { return ops_.size(); } - OpDesc *Op(int idx) { return ops_.at(idx).get(); } + OpDesc *Op(int idx) const { return ops_.at(idx).get(); } void Flush(); diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.cc b/paddle/fluid/framework/details/multi_devices_graph_builder.cc index d8e711994c5dba15ce0a1c237558b121888902e3..17baacd13eecac8f410631fe9e94788da4fff848 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_builder.cc +++ b/paddle/fluid/framework/details/multi_devices_graph_builder.cc @@ -11,11 +11,15 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/fluid/framework/details/multi_devices_graph_builder.h" +#include #include +#include #include +#include + #include "paddle/fluid/framework/details/broadcast_op_handle.h" #include "paddle/fluid/framework/details/computation_op_handle.h" +#include "paddle/fluid/framework/details/multi_devices_graph_builder.h" #include "paddle/fluid/framework/details/reduce_op_handle.h" #include "paddle/fluid/framework/details/rpc_op_handle.h" #include "paddle/fluid/framework/details/scale_loss_grad_op_handle.h" @@ -26,9 +30,6 @@ #include "paddle/fluid/framework/details/nccl_all_reduce_op_handle.h" #endif -#include -#include - DEFINE_string(ssa_graph_path, "/tmp/ssa_graph.dot", "the ssa graph path only print with GLOG_v=10," "default /tmp/graph.dot"); @@ -148,9 +149,9 @@ bool MultiDevSSAGraphBuilder::IsDistTrainOp( std::unique_ptr MultiDevSSAGraphBuilder::Build( const ProgramDesc &program) const { - std::unordered_map var_types; + std::unordered_map all_vars; for (auto *var : program.Block(0).AllVars()) { - var_types[var->Name()] = var->GetType(); + all_vars[var->Name()] = var; } auto graph = new SSAGraph(); @@ -167,12 +168,28 @@ std::unique_ptr MultiDevSSAGraphBuilder::Build( auto send_vars = FindDistTrainSendVars(program); auto recv_vars = FindDistTrainRecvVars(program); - size_t cur_device_id = 0; std::vector> var_name_on_devices; std::vector> bcast_var_name_set; var_name_on_devices.resize(places_.size()); bcast_var_name_set.resize(places_.size()); + size_t cur_device_id = 0; + std::vector balance_grads(places_.size(), 0); + + auto get_appropriate_dev = [&](std::string &g_name) -> size_t { + auto var_desc = all_vars.at(g_name); + PADDLE_ENFORCE_NOT_NULL(var_desc); + auto dim = framework::make_ddim(var_desc->GetShape()); + int64_t numel = framework::product(dim); + PADDLE_ENFORCE_GE(numel, 0); + auto smallest = + std::min_element(std::begin(balance_grads), std::end(balance_grads)); + size_t dev_id = + static_cast(std::distance(std::begin(balance_grads), smallest)); + balance_grads[dev_id] += numel; + return dev_id; + }; + bool is_forwarding = true; for (auto *op : program.Block(0).AllOps()) { if (boost::get( @@ -220,13 +237,13 @@ std::unique_ptr MultiDevSSAGraphBuilder::Build( switch (strategy_.reduce_) { case BuildStrategy::ReduceStrategy::kReduce: + cur_device_id = get_appropriate_dev(g_name); CreateReduceOp(&result, g_name, cur_device_id); var_name_on_devices[cur_device_id].emplace(g_name); bcast_var_name_set[cur_device_id].emplace(p_name); - cur_device_id = (cur_device_id + 1) % places_.size(); break; case BuildStrategy::ReduceStrategy::kAllReduce: - if (IsSparseGradient(var_types, g_name)) { + if (IsSparseGradient(all_vars, g_name)) { CreateReduceOp(&result, g_name, 0); CreateBroadcastOp(&result, g_name, 0); } else { @@ -269,10 +286,10 @@ std::unique_ptr MultiDevSSAGraphBuilder::Build( } bool MultiDevSSAGraphBuilder::IsSparseGradient( - const std::unordered_map &var_types, + const std::unordered_map &all_vars, const std::string &og) const { - PADDLE_ENFORCE(var_types.count(og) != 0); - if (var_types.at(og) == proto::VarType::SELECTED_ROWS) { + PADDLE_ENFORCE(all_vars.count(og) != 0); + if (all_vars.at(og)->GetType() == proto::VarType::SELECTED_ROWS) { return true; } return false; diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.h b/paddle/fluid/framework/details/multi_devices_graph_builder.h index e07597dbd80889c366babe79455beb12c9eb80d9..544cbe585c7423b5f3eb98ee698ca5668376f1ca 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_builder.h +++ b/paddle/fluid/framework/details/multi_devices_graph_builder.h @@ -106,7 +106,7 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder { size_t src_dev_id) const; bool IsSparseGradient( - const std::unordered_map &var_types, + const std::unordered_map &all_vars, const std::string &og) const; private: diff --git a/paddle/fluid/framework/op_desc.cc b/paddle/fluid/framework/op_desc.cc index 09b67e5a1741c68c5f5487340e8fc86ff31e00a4..f92769192c218eb7cdc2350ff6e4721b45005806 100644 --- a/paddle/fluid/framework/op_desc.cc +++ b/paddle/fluid/framework/op_desc.cc @@ -103,7 +103,7 @@ void OpDesc::CopyFrom(const OpDesc &op_desc) { need_update_ = true; } -OpDesc::OpDesc(const proto::OpDesc &desc, ProgramDesc *prog, BlockDesc *block) +OpDesc::OpDesc(const proto::OpDesc &desc, BlockDesc *block) : desc_(desc), need_update_(false) { // restore inputs_ int input_size = desc_.inputs_size(); diff --git a/paddle/fluid/framework/op_desc.h b/paddle/fluid/framework/op_desc.h index 1a330db7cc5555a939950043ac90a321573b292d..a02d3e269129596f65a2fb346e76c1af7fbead95 100644 --- a/paddle/fluid/framework/op_desc.h +++ b/paddle/fluid/framework/op_desc.h @@ -33,13 +33,14 @@ class OpDesc { OpDesc(const std::string &type, const VariableNameMap &inputs, const VariableNameMap &outputs, const AttributeMap &attrs); - OpDesc(const proto::OpDesc &desc, ProgramDesc *prog, BlockDesc *block); + OpDesc(const proto::OpDesc &desc, BlockDesc *block); explicit OpDesc(BlockDesc *block) : block_(block) {} OpDesc(const OpDesc &other, BlockDesc *block) { *this = other; block_ = block; + need_update_ = true; } void CopyFrom(const OpDesc &op_desc); diff --git a/paddle/fluid/framework/program_desc.cc b/paddle/fluid/framework/program_desc.cc index 64fb028f83a539d17885186d5d8ee6ef26f095e9..1e01a6e900404990e16674755367d2fc6d832725 100644 --- a/paddle/fluid/framework/program_desc.cc +++ b/paddle/fluid/framework/program_desc.cc @@ -51,12 +51,15 @@ ProgramDesc::ProgramDesc(const ProgramDesc &o) { auto *block = desc_.mutable_blocks(i); blocks_.emplace_back(new BlockDesc(*o.blocks_[i], block, this)); } - for (auto &block : blocks_) { - for (auto *op : block->AllOps()) { - for (const auto &attr : op->Proto()->attrs()) { - if (attr.type() == proto::AttrType::BLOCK) { - size_t blk_idx = attr.block_idx(); - op->SetBlockAttr(attr.name(), this->MutableBlock(blk_idx)); + for (size_t block_id = 0; block_id < blocks_.size(); ++block_id) { + auto all_ops = blocks_[block_id]->AllOps(); + for (size_t op_id = 0; op_id < all_ops.size(); ++op_id) { + auto &op = all_ops[op_id]; + for (const std::string &attr_name : op->AttrNames()) { + if (op->GetAttrType(attr_name) == proto::AttrType::BLOCK) { + int sub_block_id = + o.Block(block_id).Op(op_id)->GetBlockAttr(attr_name); + op->SetBlockAttr(attr_name, MutableBlock(sub_block_id)); } } } @@ -86,6 +89,16 @@ ProgramDesc::ProgramDesc(const std::string &binary_str) { for (auto &block_desc : *desc_.mutable_blocks()) { blocks_.emplace_back(new BlockDesc(this, &block_desc)); } + for (auto &block : blocks_) { + for (auto *op : block->AllOps()) { + for (const auto &attr : op->Proto()->attrs()) { + if (attr.type() == proto::AttrType::BLOCK) { + size_t blk_idx = attr.block_idx(); + op->SetBlockAttr(attr.name(), this->MutableBlock(blk_idx)); + } + } + } + } } const std::vector ProgramDesc::GetFeedTargetNames() { diff --git a/paddle/fluid/inference/tensorrt/convert/activation_op.cc b/paddle/fluid/inference/tensorrt/convert/activation_op.cc index 6297051e5a30f1daa512d25d5aa3ab3b2f79f1d1..79d01b640a214ed5eb86173a36d5e85a6626066f 100644 --- a/paddle/fluid/inference/tensorrt/convert/activation_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/activation_op.cc @@ -24,7 +24,7 @@ class ReluOpConverter : public OpConverter { void operator()(const framework::proto::OpDesc& op) override { // Here the two nullptr looks strange, that's because the // framework::OpDesc's constructor is strange. - framework::OpDesc op_desc(op, nullptr, nullptr); + framework::OpDesc op_desc(op, nullptr); LOG(INFO) << "convert a fluid relu op to tensorrt activation layer whose " "type is Relu"; const nvinfer1::ITensor* input_tensor = diff --git a/paddle/fluid/inference/tensorrt/convert/mul_op.cc b/paddle/fluid/inference/tensorrt/convert/mul_op.cc index ed09f54bde00d12aaec829ba90cc08ebfef57e92..aa8e66490f7e40038b0de4da32655f1b168ca332 100644 --- a/paddle/fluid/inference/tensorrt/convert/mul_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/mul_op.cc @@ -27,7 +27,7 @@ class MulOpConverter : public OpConverter { void operator()(const framework::proto::OpDesc& op) override { VLOG(4) << "convert a fluid mul op to tensorrt fc layer without bias"; - framework::OpDesc op_desc(op, nullptr, nullptr); + framework::OpDesc op_desc(op, nullptr); // Declare inputs auto* input1 = engine_->GetITensor(op_desc.Input("X")[0]); auto* input2 = engine_->GetITensor(op_desc.Input("Y")[0]); diff --git a/paddle/fluid/inference/tensorrt/convert/ut_helper.h b/paddle/fluid/inference/tensorrt/convert/ut_helper.h index e46c577cdae145c0d4ceb6bfa307f03d313514ce..684bbc208fc1cb02d2a36b4de720309ea6bed173 100644 --- a/paddle/fluid/inference/tensorrt/convert/ut_helper.h +++ b/paddle/fluid/inference/tensorrt/convert/ut_helper.h @@ -104,7 +104,7 @@ class TRTConvertValidation { engine_->FreezeNetwork(); // Declare outputs. - op_desc_.reset(new framework::OpDesc(desc, nullptr, nullptr)); + op_desc_.reset(new framework::OpDesc(desc, nullptr)); // Set Inputs. for (const auto& input : op_desc_->InputArgumentNames()) { diff --git a/paddle/fluid/operators/bilinear_interp_op.cc b/paddle/fluid/operators/bilinear_interp_op.cc index d46fda54e7a9d5bc737a7ec2116daca33ffa015f..3321adf2743c28f6eeca8b5cc91ef89beed6b97c 100644 --- a/paddle/fluid/operators/bilinear_interp_op.cc +++ b/paddle/fluid/operators/bilinear_interp_op.cc @@ -34,9 +34,22 @@ class BilinearInterpOp : public framework::OperatorWithKernel { int out_w = ctx->Attrs().Get("out_w"); PADDLE_ENFORCE_EQ(dim_x.size(), 4, "X's dimension must be 4"); + if (ctx->HasInput("OutSize")) { + auto out_size_dim = ctx->GetInputDim("OutSize"); + PADDLE_ENFORCE_EQ(out_size_dim.size(), 1, + "OutSize's dimension size must be 1"); + PADDLE_ENFORCE_EQ(out_size_dim[0], 2, "OutSize's dim[0] must be 2"); + } std::vector dim_out({dim_x[0], dim_x[1], out_h, out_w}); ctx->SetOutputDim("Out", framework::make_ddim(dim_out)); } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType( + framework::ToDataType(ctx.Input("X")->type()), ctx.GetPlace()); + } }; class BilinearInterpOpMaker : public framework::OpProtoAndCheckerMaker { @@ -45,6 +58,10 @@ class BilinearInterpOpMaker : public framework::OpProtoAndCheckerMaker { AddInput("X", "(Tensor) The input tensor of bilinear interpolation, " "This is a 4-D tensor with shape of (N x C x h x w)"); + AddInput("OutSize", + "(Tensor) This is a 1-D tensor with two number. " + "The first number is height and the second number is width.") + .AsDispensable(); AddOutput("Out", "(Tensor) The dimension of output is (N x C x out_h x out_w]"); @@ -78,6 +95,12 @@ class BilinearInterpOpGrad : public framework::OperatorWithKernel { ctx->SetOutputDim(framework::GradVarName("X"), dim_x); } } + + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType( + framework::ToDataType(ctx.Input("X")->type()), ctx.GetPlace()); + } }; } // namespace operators diff --git a/paddle/fluid/operators/bilinear_interp_op.cu b/paddle/fluid/operators/bilinear_interp_op.cu index 510190f1aaf02960284216a1bedd409011088499..4c1971538495c6f111e9db18f4014786f6f0dd58 100644 --- a/paddle/fluid/operators/bilinear_interp_op.cu +++ b/paddle/fluid/operators/bilinear_interp_op.cu @@ -102,10 +102,21 @@ class BilinearInterpOpCUDAKernel : public framework::OpKernel { auto* input_t = ctx.Input("X"); // float tensor auto* output_t = ctx.Output("Out"); // float tensor auto* input = input_t->data(); - auto* output = output_t->mutable_data(ctx.GetPlace()); int out_h = ctx.Attr("out_h"); int out_w = ctx.Attr("out_w"); + auto out_dims = output_t->dims(); + auto out_size_t = ctx.Input("OutSize"); + if (out_size_t != nullptr) { + Tensor sizes; + framework::TensorCopy(*out_size_t, platform::CPUPlace(), &sizes); + auto size_data = sizes.data(); + out_h = size_data[0]; + out_w = size_data[1]; + } + auto* output = output_t->mutable_data( + {out_dims[0], out_dims[1], out_h, out_w}, ctx.GetPlace()); + int batch_size = input_t->dims()[0]; int channels = input_t->dims()[1]; int in_h = input_t->dims()[2]; @@ -139,8 +150,8 @@ class BilinearInterpGradOpCUDAKernel : public framework::OpKernel { void Compute(const framework::ExecutionContext& ctx) const override { auto* d_input_t = ctx.Output(framework::GradVarName("X")); auto* d_output_t = ctx.Input(framework::GradVarName("Out")); - auto* d_input = d_input_t->mutable_data(ctx.GetPlace()); auto* d_output = d_output_t->data(); + auto* d_input = d_input_t->mutable_data(ctx.GetPlace()); auto& device_ctx = ctx.template device_context(); @@ -149,6 +160,16 @@ class BilinearInterpGradOpCUDAKernel : public framework::OpKernel { int out_h = ctx.Attr("out_h"); int out_w = ctx.Attr("out_w"); + + auto out_size_t = ctx.Input("OutSize"); + if (out_size_t != nullptr) { + Tensor sizes; + framework::TensorCopy(*out_size_t, platform::CPUPlace(), &sizes); + auto size_data = sizes.data(); + out_h = size_data[0]; + out_w = size_data[1]; + } + int batch_size = d_input_t->dims()[0]; int channels = d_input_t->dims()[1]; int in_h = d_input_t->dims()[2]; diff --git a/paddle/fluid/operators/bilinear_interp_op.h b/paddle/fluid/operators/bilinear_interp_op.h index f6cd77e4d49b53ecde6a84908cdffc7e1e02ac6a..8b03cd5a0635584a45782fe5a4823c37fe4fa8e8 100644 --- a/paddle/fluid/operators/bilinear_interp_op.h +++ b/paddle/fluid/operators/bilinear_interp_op.h @@ -24,11 +24,18 @@ class BilinearInterpKernel : public framework::OpKernel { void Compute(const framework::ExecutionContext& ctx) const override { auto* input_t = ctx.Input("X"); // float tensor auto* output_t = ctx.Output("Out"); // float tensor + auto out_dims = output_t->dims(); auto* input = input_t->data(); - auto* output = output_t->mutable_data(ctx.GetPlace()); - int out_h = ctx.Attr("out_h"); int out_w = ctx.Attr("out_w"); + auto out_size_t = ctx.Input("OutSize"); + if (out_size_t != nullptr) { + auto out_size_data = out_size_t->data(); + out_h = out_size_data[0]; + out_w = out_size_data[1]; + } + auto* output = output_t->mutable_data( + {out_dims[0], out_dims[1], out_h, out_w}, ctx.GetPlace()); int batch_size = input_t->dims()[0]; int channels = input_t->dims()[1]; int in_h = input_t->dims()[2]; @@ -83,9 +90,8 @@ class BilinearInterpGradKernel : public framework::OpKernel { void Compute(const framework::ExecutionContext& ctx) const override { auto* d_input_t = ctx.Output(framework::GradVarName("X")); auto* d_output_t = ctx.Input(framework::GradVarName("Out")); - auto* d_input = d_input_t->mutable_data(ctx.GetPlace()); auto* d_output = d_output_t->data(); - + auto* d_input = d_input_t->mutable_data(ctx.GetPlace()); auto& device_ctx = ctx.template device_context(); math::SetConstant zero; @@ -93,6 +99,14 @@ class BilinearInterpGradKernel : public framework::OpKernel { int out_h = ctx.Attr("out_h"); int out_w = ctx.Attr("out_w"); + + auto out_size_t = ctx.Input("OutSize"); + if (out_size_t != nullptr) { + auto out_size_data = out_size_t->data(); + out_h = out_size_data[0]; + out_w = out_size_data[1]; + } + int batch_size = d_input_t->dims()[0]; int channels = d_input_t->dims()[1]; int in_h = d_input_t->dims()[2]; diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 63ec83151477770ea64070cae4f5e4fcc497f7af..cb87653c47727052bb7835dcbe0a590887a560c1 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -3944,7 +3944,7 @@ def upsampling_bilinear2d(input, out_shape=None, scale=None, name=None): input (Variable): The input tensor of bilinear interpolation, This is a 4-D tensor of the shape (num_batches, channels, in_h, in_w). - out_shape(list|tuple|None): Output shape of bilinear interpolation + out_shape(list|tuple|Variable|None): Output shape of bilinear interpolation layer, the shape is (out_h, out_w). Default: None scale(int|None): The multiplier for the input height or width. @@ -3971,13 +3971,20 @@ def upsampling_bilinear2d(input, out_shape=None, scale=None, name=None): def _is_list_or_turple_(data): return (isinstance(data, list) or isinstance(data, tuple)) + out_h = 0 + out_w = 0 + inputs = {"X": input} if out_shape is not None: - if not (_is_list_or_turple_(out_shape) and len(out_shape) == 2): + if not (_is_list_or_turple_(out_shape) and len(out_shape) == 2) and ( + out_shape is not Variable): raise ValueError('out_shape should be a list or tuple ', 'with length 2, (out_h, out_w).') - out_shape = list(map(int, out_shape)) - out_h = out_shape[0] - out_w = out_shape[1] + if _is_list_or_turple_(out_shape): + out_shape = list(map(int, out_shape)) + out_h = out_shape[0] + out_w = out_shape[1] + else: + inputs['OutSize'] = out_shape else: out_h = int(input.shape[2] * scale) out_w = int(input.shape[3] * scale) @@ -3985,7 +3992,7 @@ def upsampling_bilinear2d(input, out_shape=None, scale=None, name=None): out = helper.create_tmp_variable(dtype) helper.append_op( type="bilinear_interp", - inputs={"X": input}, + inputs=inputs, outputs={"Out": out}, attrs={"out_h": out_h, "out_w": out_w}) diff --git a/python/paddle/fluid/tests/unittests/test_bilinear_interp_op.py b/python/paddle/fluid/tests/unittests/test_bilinear_interp_op.py index bffb4f3b666a7ddcc133b7c30fab132b49aa1d0e..87c11e7880e73b911f21dda77c1cc2b4850b3591 100644 --- a/python/paddle/fluid/tests/unittests/test_bilinear_interp_op.py +++ b/python/paddle/fluid/tests/unittests/test_bilinear_interp_op.py @@ -17,7 +17,10 @@ import numpy as np from op_test import OpTest -def bilinear_interp_np(input, out_h, out_w): +def bilinear_interp_np(input, out_h, out_w, out_size): + if out_size is not None: + out_h = out_size[0] + out_w = out_size[1] batch_size, channel, in_h, in_w = input.shape if out_h > 1: ratio_h = (in_h - 1.0) / (out_h - 1.0) @@ -49,12 +52,15 @@ def bilinear_interp_np(input, out_h, out_w): class TestBilinearInterpOp(OpTest): def setUp(self): + self.out_size = None self.init_test_case() self.op_type = "bilinear_interp" input_np = np.random.random(self.input_shape).astype("float32") - output_np = bilinear_interp_np(input_np, self.out_h, self.out_w) - + output_np = bilinear_interp_np(input_np, self.out_h, self.out_w, + self.out_size) self.inputs = {'X': input_np} + if self.out_size is not None: + self.inputs['OutSize'] = self.out_size self.attrs = {'out_h': self.out_h, 'out_w': self.out_w} self.outputs = {'Out': output_np} @@ -68,6 +74,7 @@ class TestBilinearInterpOp(OpTest): self.input_shape = [2, 3, 4, 4] self.out_h = 2 self.out_w = 2 + self.out_size = np.array([3, 3]).astype("int32") class TestCase1(TestBilinearInterpOp): @@ -91,5 +98,29 @@ class TestCase3(TestBilinearInterpOp): self.out_w = 128 +class TestCase4(TestBilinearInterpOp): + def init_test_case(self): + self.input_shape = [4, 1, 7, 8] + self.out_h = 1 + self.out_w = 1 + self.out_size = np.array([2, 2]).astype("int32") + + +class TestCase5(TestBilinearInterpOp): + def init_test_case(self): + self.input_shape = [3, 3, 9, 6] + self.out_h = 12 + self.out_w = 12 + self.out_size = np.array([11, 11]).astype("int32") + + +class TestCase6(TestBilinearInterpOp): + def init_test_case(self): + self.input_shape = [1, 1, 128, 64] + self.out_h = 64 + self.out_w = 128 + self.out_size = np.array([65, 129]).astype("int32") + + if __name__ == "__main__": unittest.main()