diff --git a/paddle/fluid/inference/analysis/analyzer.cc b/paddle/fluid/inference/analysis/analyzer.cc index 8a8aeb5e09a0d9a6746f6d6d61c547363e0e2d30..d780592eb9f79e39e34fcd3bd6b086992eaa931f 100644 --- a/paddle/fluid/inference/analysis/analyzer.cc +++ b/paddle/fluid/inference/analysis/analyzer.cc @@ -70,7 +70,7 @@ class DfgPassManagerImpl final : public DfgPassManager { auto trt_teller = [&](const Node* node) { std::unordered_set teller_set( {"mul", "conv2d", "pool2d", "relu", "softmax", "sigmoid", - "depthwise_conv2d", "batch_norm", "concat", "tanh", + "depthwise_conv2d", "batch_norm", "concat", "tanh", "pad", "elementwise_add", "dropout"}); if (!node->IsFunction()) return false; diff --git a/paddle/fluid/inference/api/api_tensorrt_subgraph_engine.cc b/paddle/fluid/inference/api/api_tensorrt_subgraph_engine.cc index 5ee6a5a93168f58770067f76ca7f6bb6f67b2965..7ac468ee4d33f49bba20a07c976055a083743cbc 100644 --- a/paddle/fluid/inference/api/api_tensorrt_subgraph_engine.cc +++ b/paddle/fluid/inference/api/api_tensorrt_subgraph_engine.cc @@ -185,3 +185,4 @@ USE_TRT_CONVERTER(softmax); USE_TRT_CONVERTER(batch_norm); USE_TRT_CONVERTER(concat); USE_TRT_CONVERTER(dropout); +USE_TRT_CONVERTER(pad); diff --git a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt index fac1babf6ec6131f84d3e3b9fc6efedd9f9f6cfc..0a35e10f6936313928ab21a6f17c40335e8fc882 100644 --- a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt +++ b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt @@ -1,7 +1,7 @@ # Add TRT tests nv_library(tensorrt_converter SRCS mul_op.cc conv2d_op.cc fc_op.cc pool2d_op.cc elementwise_op.cc -batch_norm_op.cc activation_op.cc softmax_op.cc concat_op.cc dropout_op.cc +batch_norm_op.cc activation_op.cc softmax_op.cc concat_op.cc dropout_op.cc pad_op.cc DEPS tensorrt_engine operator scope framework_proto op_registry) nv_test(test_op_converter SRCS test_op_converter.cc DEPS @@ -26,6 +26,8 @@ nv_test(test_trt_batch_norm_op SRCS test_batch_norm_op.cc batch_norm_op.cc DEPS ${FLUID_CORE_MODULES} tensorrt_engine batch_norm_op SERIAL) nv_test(test_trt_concat_op SRCS test_concat_op.cc concat_op.cc DEPS ${FLUID_CORE_MODULES} tensorrt_engine concat_op SERIAL) - nv_test(test_trt_dropout_op SRCS test_dropout_op.cc dropout_op.cc DEPS ${FLUID_CORE_MODULES} tensorrt_engine dropout_op SERIAL) + +nv_test(test_trt_pad_op SRCS test_pad_op.cc pad_op.cc + DEPS ${FLUID_CORE_MODULES} tensorrt_engine pad_op SERIAL) diff --git a/paddle/fluid/inference/tensorrt/convert/pad_op.cc b/paddle/fluid/inference/tensorrt/convert/pad_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..218030a591fcc7e533ef37062265449d4b6044bc --- /dev/null +++ b/paddle/fluid/inference/tensorrt/convert/pad_op.cc @@ -0,0 +1,68 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/inference/tensorrt/convert/op_converter.h" + +namespace paddle { +namespace inference { +namespace tensorrt { + +/* + * PadOp. + */ +class PadOpConverter : public OpConverter { + public: + void operator()(const framework::proto::OpDesc& op, + const framework::Scope& scope, bool test_mode) override { + VLOG(4) << "convert a fluid transpose op to tensorrt tranpose layer"; + + framework::OpDesc op_desc(op, nullptr); + // Declare inputs + auto* input = engine_->GetITensor(op_desc.Input("X")[0]); + + const std::vector paddings = + boost::get>(op_desc.GetAttr("paddings")); + const float pad_value = boost::get(op_desc.GetAttr("pad_value")); + + nvinfer1::Dims input_shape = input->getDimensions(); + int nbDims = input_shape.nbDims; + int pad_size = static_cast(paddings.size()); + PADDLE_ENFORCE_GE(nbDims, 2); + PADDLE_ENFORCE_EQ((nbDims + 1) * 2, pad_size); + PADDLE_ENFORCE(pad_value == 0.0, "The pad layer of TRT only support zero."); + + nvinfer1::DimsHW pre_pad(paddings[pad_size - 4], paddings[pad_size - 2]); + nvinfer1::DimsHW post_pad(paddings[pad_size - 3], paddings[pad_size - 1]); + + auto* layer = TRT_ENGINE_ADD_LAYER(engine_, Padding, + *const_cast(input), + pre_pad, post_pad); + + PADDLE_ENFORCE(layer != nullptr); + auto output_name = op_desc.Output("Out")[0]; + engine_->SetITensor(output_name, layer->getOutput(0)); + layer->setName(("scale (Output: " + output_name + ")").c_str()); + layer->getOutput(0)->setName(output_name.c_str()); + if (test_mode) { // the test framework can not determine which is the + // output, so place the declaration inside. + engine_->DeclareOutput(output_name); + } + } +}; + +} // namespace tensorrt +} // namespace inference +} // namespace paddle + +REGISTER_TRT_OP_CONVERTER(pad, PadOpConverter); diff --git a/paddle/fluid/inference/tensorrt/convert/test_pad_op.cc b/paddle/fluid/inference/tensorrt/convert/test_pad_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..ba35d7ddbb2f4e6062713bd82be277e7ad0cb341 --- /dev/null +++ b/paddle/fluid/inference/tensorrt/convert/test_pad_op.cc @@ -0,0 +1,52 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/inference/tensorrt/convert/ut_helper.h" + +namespace paddle { +namespace inference { +namespace tensorrt { + +TEST(PadConverter, main) { + framework::Scope scope; + std::unordered_set parameters; + TRTConvertValidation validator(10, parameters, scope, 1000); + validator.DeclInputVar("pad-X", nvinfer1::Dims3(3, 2, 2)); + validator.DeclOutputVar("pad-Out", nvinfer1::Dims3(3, 3, 5)); + + // Prepare Op description + framework::OpDesc desc; + desc.SetType("pad"); + desc.SetInput("X", {"pad-X"}); + desc.SetOutput("Out", {"pad-Out"}); + + std::vector paddings = {0, 0, 0, 0, 0, 1, 1, 2}; + float pad_value = 0.0; + desc.SetAttr("paddings", paddings); + desc.SetAttr("pad_value", pad_value); + + LOG(INFO) << "set OP"; + validator.SetOp(*desc.Proto()); + LOG(INFO) << "execute"; + + validator.Execute(2); +} + +} // namespace tensorrt +} // namespace inference +} // namespace paddle + +USE_OP(pad); diff --git a/paddle/fluid/operators/math/CMakeLists.txt b/paddle/fluid/operators/math/CMakeLists.txt index 91101356436c26171eaca2fe01dfd4d937e71717..b0276f4080b3989047f9d181e9d37a18dcfad5fa 100644 --- a/paddle/fluid/operators/math/CMakeLists.txt +++ b/paddle/fluid/operators/math/CMakeLists.txt @@ -3,8 +3,8 @@ add_subdirectory(detail) endif(NOT WIN32) function(math_library TARGET) - # math_library is a function to create math library. - # The interface is the same as cc_library. + # math_library is a function to create math library. + # The interface is the same as cc_library. # But it handle split GPU/CPU code and link some common library. set(cc_srcs) set(cu_srcs) @@ -53,7 +53,7 @@ cc_library(blas SRCS blas.cc DEPS cblas framework_proto device_context) math_library(math_function DEPS blas) math_library(maxouting) math_library(pooling) -math_library(selected_rows_functor DEPS selected_rows math_function) +math_library(selected_rows_functor DEPS selected_rows math_function blas) math_library(sequence2batch) math_library(sequence_padding) math_library(sequence_pooling DEPS math_function) diff --git a/paddle/fluid/operators/math/selected_rows_functor.cc b/paddle/fluid/operators/math/selected_rows_functor.cc index 374198f75e7eeac5948110517450fb36aa841219..34fb168036cd2498860bf4f3f1d10e875f650804 100644 --- a/paddle/fluid/operators/math/selected_rows_functor.cc +++ b/paddle/fluid/operators/math/selected_rows_functor.cc @@ -15,7 +15,6 @@ limitations under the License. */ #include #include -#include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/operators/math/selected_rows_functor.h" namespace paddle { @@ -150,6 +149,45 @@ template struct SelectedRowsAddTo; template struct SelectedRowsAddTo; template struct SelectedRowsAddTo; +template +struct SelectedRowsSumTo { + void operator()(const platform::CPUDeviceContext& context, + const std::vector& input1, + const std::vector& input2_offsets, + framework::SelectedRows* input2) { + // Ensure all selected rows have the same height + size_t size = 0u; + for (auto iter = input1.begin(); iter != input1.end(); ++iter) { + auto& in_rows = (*iter)->rows(); + size += in_rows.end() - in_rows.begin(); + auto in1_height = (*iter)->height(); + PADDLE_ENFORCE_EQ(in1_height, input2->height()); + } + // concat rows + std::vector in2_rows; + in2_rows.reserve(in2_rows.size() + size); + for (auto iter = input1.begin(); iter != input1.end(); ++iter) { + const framework::Vector& in_rows = (*iter)->rows(); + in2_rows.insert(in2_rows.end(), in_rows.begin(), in_rows.end()); + } + input2->set_rows(in2_rows); + + auto* in2_value = input2->mutable_value(); + auto* in2_data = in2_value->data(); + auto blas = math::GetBlas(context); + size_t offset = 0u; + for (size_t i = 0u; i != input1.size(); ++i) { + auto& in_value = input1[i]->value(); + const auto* in_data = in_value.data(); + offset += input2_offsets[i]; + blas.VCOPY(in_value.numel(), in_data, in2_data + offset); + } + } +}; + +template struct SelectedRowsSumTo; +template struct SelectedRowsSumTo; + template struct SelectedRowsAddToTensor { void operator()(const platform::CPUDeviceContext& context, @@ -260,8 +298,6 @@ struct MergeAdd { } }; -template struct MergeAdd; -template struct MergeAdd; template struct MergeAdd; template struct MergeAdd; diff --git a/paddle/fluid/operators/math/selected_rows_functor.h b/paddle/fluid/operators/math/selected_rows_functor.h index dfabebcded1bab6fd161ac6c97168341cb59d6e3..f003bcd8db20664d37066615abe619998b0fcd94 100644 --- a/paddle/fluid/operators/math/selected_rows_functor.h +++ b/paddle/fluid/operators/math/selected_rows_functor.h @@ -12,10 +12,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #pragma once + #include #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/selected_rows.h" +#include "paddle/fluid/operators/math/blas.h" +#include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/platform/device_context.h" #define INLINE_FOR2(sizei, sizej) \ @@ -51,6 +54,15 @@ struct SelectedRowsAddTo { const int64_t input2_offset, framework::SelectedRows* input2); }; +// input2 = [all input in input1] + input2 +template +struct SelectedRowsSumTo { + void operator()(const DeviceContext& context, + const std::vector& input1, + const std::vector& input2_offsets, + framework::SelectedRows* input2); +}; + // input2 = input1 + input2 template struct SelectedRowsAddToTensor { @@ -75,6 +87,108 @@ struct MergeAdd { framework::SelectedRows* output); }; +template <> +struct MergeAdd { + framework::SelectedRows operator()(const platform::CPUDeviceContext& context, + const framework::SelectedRows& input) { + framework::SelectedRows out; + (*this)(context, input, &out); + return out; + } + + void operator()(const platform::CPUDeviceContext& context, + const framework::SelectedRows& input, + framework::SelectedRows* output) { + framework::SelectedRows& out = *output; + auto input_rows = input.rows(); + std::vector merge_rows; + merge_rows.reserve(input_rows.size()); + std::unordered_map rows_pos_map; + rows_pos_map.reserve(input_rows.size()); + size_t idx = 0u; + for (std::vector::iterator iter = input_rows.begin(); + iter != input_rows.end(); ++iter) { + if (rows_pos_map.find(*iter) == rows_pos_map.end()) { + rows_pos_map[*iter] = idx++; + merge_rows.emplace_back(*iter); + } + } + + auto input_width = input.value().dims()[1]; + out.set_rows(merge_rows); + out.set_height(input.height()); + out.mutable_value()->mutable_data( + framework::make_ddim( + {static_cast(merge_rows.size()), input_width}), + context.GetPlace()); + + math::SetConstant constant_functor; + constant_functor(context, out.mutable_value(), 0.0); + + auto* out_data = out.mutable_value()->data(); + auto* input_data = input.value().data(); + + auto blas = GetBlas(context); + for (size_t i = 0; i < input_rows.size(); i++) { + size_t out_i = rows_pos_map[input_rows[i]]; + float* y = out_data + out_i * input_width; + const float* x = input_data + i * input_width; + blas.AXPY(input_width, 1., x, y); + } + } +}; + +template <> +struct MergeAdd { + framework::SelectedRows operator()(const platform::CPUDeviceContext& context, + const framework::SelectedRows& input) { + framework::SelectedRows out; + (*this)(context, input, &out); + return out; + } + + void operator()(const platform::CPUDeviceContext& context, + const framework::SelectedRows& input, + framework::SelectedRows* output) { + framework::SelectedRows& out = *output; + auto input_rows = input.rows(); + std::vector merge_rows; + merge_rows.reserve(input_rows.size()); + std::unordered_map rows_pos_map; + rows_pos_map.reserve(input_rows.size()); + size_t idx = 0u; + for (std::vector::iterator iter = input_rows.begin(); + iter != input_rows.end(); ++iter) { + if (rows_pos_map.find(*iter) == rows_pos_map.end()) { + rows_pos_map[*iter] = idx++; + merge_rows.emplace_back(*iter); + } + } + + auto input_width = input.value().dims()[1]; + out.set_rows(merge_rows); + out.set_height(input.height()); + out.mutable_value()->mutable_data( + framework::make_ddim( + {static_cast(merge_rows.size()), input_width}), + context.GetPlace()); + + math::SetConstant constant_functor; + constant_functor(context, out.mutable_value(), 0.0); + + auto* out_data = out.mutable_value()->data(); + auto* input_data = input.value().data(); + + auto blas = GetBlas(context); + for (size_t i = 0; i < input_rows.size(); i++) { + size_t out_i = rows_pos_map[input_rows[i]]; + double* y = out_data + out_i * input_width; + const double* x = input_data + i * input_width; + blas.AXPY(input_width, 1., x, y); + } + } +}; + template struct Add { framework::SelectedRows operator()(const DeviceContext& context, diff --git a/paddle/fluid/operators/math/selected_rows_functor_test.cc b/paddle/fluid/operators/math/selected_rows_functor_test.cc index 2a2fa652be10a2f0c72a940964af3e1705aa99f2..e114e58dee23443750b8fe470c20a1dae84d5903 100644 --- a/paddle/fluid/operators/math/selected_rows_functor_test.cc +++ b/paddle/fluid/operators/math/selected_rows_functor_test.cc @@ -220,17 +220,98 @@ TEST(selected_rows_functor, cpu_add_to) { EXPECT_EQ(tensor1_data[9 * row_numel + 6], 5.0); } -TEST(selected_rows_functor, cpu_merge_add) { +TEST(selected_rows_functor, cpu_merge_add_float) { paddle::platform::CPUPlace cpu_place; paddle::platform::CPUDeviceContext ctx(cpu_place); paddle::operators::math::SetConstant - set_const; + functor; + int64_t height = 10; + int64_t row_numel = 10; + + std::vector rows{0, 4, 4, 7}; + std::unique_ptr selected_rows{ + new paddle::framework::SelectedRows(rows, height)}; + auto* in_value = selected_rows->mutable_value(); + in_value->mutable_data( + paddle::framework::make_ddim( + {static_cast(rows.size()), row_numel}), + cpu_place); + functor(ctx, in_value, 1.0); + + std::unique_ptr output{ + new paddle::framework::SelectedRows()}; + + paddle::operators::math::scatter::MergeAdd + merge_add_functor; + merge_add_functor(ctx, *selected_rows, output.get()); + + auto out_height = output->height(); + EXPECT_EQ(out_height, height); + + auto& out_rows = output->rows(); + EXPECT_EQ(out_rows[0], 0); + EXPECT_EQ(out_rows[1], 4); + EXPECT_EQ(out_rows[2], 7); + + auto* out_data = output->value().data(); + EXPECT_EQ(out_data[0 * row_numel], 1.0); + EXPECT_EQ(out_data[1 * row_numel], 2.0); + EXPECT_EQ(out_data[2 * row_numel], 1.0); +} + +TEST(selected_rows_functor, cpu_merge_add_int) { + paddle::platform::CPUPlace cpu_place; + paddle::platform::CPUDeviceContext ctx(cpu_place); + paddle::operators::math::SetConstant + functor; int64_t height = 10; - int64_t row_numel = 8; + int64_t row_numel = 10; - std::vector rows1{5, 2, 5, 3, 5}; + std::vector rows{0, 4, 4, 7}; + std::unique_ptr selected_rows{ + new paddle::framework::SelectedRows(rows, height)}; + auto* in_value = selected_rows->mutable_value(); + in_value->mutable_data( + paddle::framework::make_ddim( + {static_cast(rows.size()), row_numel}), + cpu_place); + functor(ctx, in_value, 1); + + std::unique_ptr output{ + new paddle::framework::SelectedRows()}; + + paddle::operators::math::scatter::MergeAdd + merge_add_functor; + merge_add_functor(ctx, *selected_rows, output.get()); + + auto out_height = output->height(); + EXPECT_EQ(out_height, height); + + auto& out_rows = output->rows(); + EXPECT_EQ(out_rows[0], 0); + EXPECT_EQ(out_rows[1], 4); + EXPECT_EQ(out_rows[2], 7); + + auto* out_data = output->value().data(); + + EXPECT_EQ(out_data[0 * row_numel], 1); + EXPECT_EQ(out_data[1 * row_numel], 2); + EXPECT_EQ(out_data[2 * row_numel], 1); +} + +TEST(selected_rows_functor, cpu_sum_to) { + paddle::platform::CPUPlace cpu_place; + paddle::platform::CPUDeviceContext ctx(cpu_place); + paddle::operators::math::SetConstant + functor; + int64_t height = 10; + int64_t row_numel = 10; + std::vector rows1{0, 4, 7}; std::unique_ptr selected_rows1{ new paddle::framework::SelectedRows(rows1, height)}; auto* in1_value = selected_rows1->mutable_value(); @@ -238,9 +319,9 @@ TEST(selected_rows_functor, cpu_merge_add) { paddle::framework::make_ddim( {static_cast(rows1.size()), row_numel}), cpu_place); - set_const(ctx, in1_value, 1.0); - std::vector rows2{2, 5, 3, 5, 3}; + functor(ctx, in1_value, 1.0); + std::vector rows2{0, 5, 7, 9}; std::unique_ptr selected_rows2{ new paddle::framework::SelectedRows(rows2, height)}; auto* in2_value = selected_rows2->mutable_value(); @@ -248,33 +329,67 @@ TEST(selected_rows_functor, cpu_merge_add) { paddle::framework::make_ddim( {static_cast(rows2.size()), row_numel}), cpu_place); - set_const(ctx, in2_value, 1.0); + functor(ctx, in2_value, 2.0); std::unique_ptr output{ new paddle::framework::SelectedRows()}; output->set_height(height); - paddle::operators::math::scatter::MergeAddmutable_value(); + // simplely concat two SelectedRows + out_value->mutable_data(paddle::framework::make_ddim({7, 10}), + cpu_place); + paddle::operators::math::SelectedRowsSumTo - merge_add_functor; - - std::vector inputs; - inputs.push_back(selected_rows1.get()); - inputs.push_back(selected_rows2.get()); - merge_add_functor(ctx, inputs, output.get()); - - EXPECT_EQ(output->height(), height); - EXPECT_EQ(output->value().dims(), - paddle::framework::make_ddim({3, row_numel})); - - std::vector ret_rows{2, 3, 5}; - EXPECT_EQ(output->rows(), ret_rows); - + sum_to_functor; + sum_to_functor(ctx, std::vector( + {selected_rows1.get(), selected_rows2.get()}), + std::vector({0, in1_value->numel()}), output.get()); + auto out_height = output->height(); + EXPECT_EQ(out_height, height); + auto& out_rows = output->rows(); + // input1 rows + EXPECT_EQ(out_rows[0], 0); + EXPECT_EQ(out_rows[1], 4); + EXPECT_EQ(out_rows[2], 7); + // input2 rows + EXPECT_EQ(out_rows[3], 0); + EXPECT_EQ(out_rows[4], 5); + EXPECT_EQ(out_rows[5], 7); + EXPECT_EQ(out_rows[6], 9); auto* out_data = output->value().data(); - for (size_t i = 0; i < ret_rows.size(); ++i) { - for (size_t j = 0; j < row_numel; ++j) { - EXPECT_EQ(out_data[i * row_numel + j], ret_rows[i]); - std::cout << out_data[i * row_numel + j] << " "; - } - std::cout << "\n"; - } + // input1 value + EXPECT_EQ(out_data[0 * row_numel + 0], 1.0); + EXPECT_EQ(out_data[0 * row_numel + 8], 1.0); + EXPECT_EQ(out_data[1 * row_numel + 1], 1.0); + EXPECT_EQ(out_data[2 * row_numel + 6], 1.0); + // input2 value + EXPECT_EQ(out_data[3 * row_numel + 3], 2.0); + EXPECT_EQ(out_data[3 * row_numel + 8], 2.0); + EXPECT_EQ(out_data[4 * row_numel + 4], 2.0); + EXPECT_EQ(out_data[5 * row_numel + 7], 2.0); + EXPECT_EQ(out_data[6 * row_numel + 9], 2.0); + std::unique_ptr tensor1{ + new paddle::framework::Tensor()}; + tensor1->mutable_data( + paddle::framework::make_ddim({height, row_numel}), cpu_place); + functor(ctx, tensor1.get(), 3.0); + paddle::operators::math::SelectedRowsAddToTensor< + paddle::platform::CPUDeviceContext, float> + add_to_tensor_functor; + add_to_tensor_functor(ctx, *output, tensor1.get()); + auto* tensor1_data = tensor1->data(); + // row0: 1.0 + 2.0 + 3.0 + EXPECT_EQ(tensor1_data[0 * row_numel + 0], 6.0); + // row1: 3.0 + EXPECT_EQ(tensor1_data[1 * row_numel + 1], 3.0); + // row4 : 1.0 + 3.0 + EXPECT_EQ(tensor1_data[4 * row_numel + 6], 4.0); + // row5: 2.0 + 3.0 + EXPECT_EQ(tensor1_data[5 * row_numel + 7], 5.0); + // row6: 3.0 + EXPECT_EQ(tensor1_data[6 * row_numel + 1], 3.0); + // row7: 1.0 + 2.0 + 3.0 + EXPECT_EQ(tensor1_data[7 * row_numel + 3], 6.0); + // row9: 2.0 + 3.0 + EXPECT_EQ(tensor1_data[9 * row_numel + 6], 5.0); } diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index a91894ba8935076748aa3a8ded8d8829b88ebb33..8d58e018c345f7c0ca188d4380f675d2bce2a847 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -667,16 +667,17 @@ All parameter, weight, gradient are variables in Paddle. ExecutionStrategy allows the user to more preciously control how to run the program in ParallelExecutor by setting the property. - The available properties include: - use_cuda (bool): Whether to use CUDA or not. Default True. - num_threads (int): The number of threads that used to run the - operators in ParallelExecutor. If it is not set, it will be - set in ParallelExecutor according to the device count. - Default 0. - allow_op_delay (bool): Whether to delay the communication operators - to run. Default False. - num_iteration_per_drop_scope (int): how many iterations between - the two dropping local scopes. Default 100. + Examples: + .. code-block:: python + + exec_strategy = fluid.ExecutionStrategy() + exec_strategy.num_threads = 4 + + train_exe = fluid.ParallelExecutor(use_cuda=True, + loss_name=loss.name, + exec_strategy=exec_strategy) + + train_loss, = train_exe.run([loss.name], feed=feed_dict) )DOC"); @@ -686,19 +687,34 @@ All parameter, weight, gradient are variables in Paddle. [](const ExecutionStrategy &self) { return self.num_threads_; }, [](ExecutionStrategy &self, size_t num_threads) { self.num_threads_ = num_threads; - }) + }, + R"DOC(The type is INT, num_threads represents the size of thread pool that + used to run the operators of the current program in ParallelExecutor. + If :math:`num\_threads=1`, all the operators will execute one by one, + but the order maybe difference between iterations. + If it is not set, it will be set in ParallelExecutor according to the + device type and device count, for GPU, :math:`num\_threads=device\_count*4`, for CPU, + :math:`num\_threads=CPU\_NUM*4`, the explanation of:math:`CPU\_NUM` is in ParallelExecutor. + if it is not set, ParallelExecutor will get the cpu count by calling + `multiprocessing.cpu_count()`. Default 0.)DOC") .def_property( "use_cuda", [](const ExecutionStrategy &self) { return self.use_cuda_; }, [](ExecutionStrategy &self, bool use_cuda) { self.use_cuda_ = use_cuda; - }) + }) // FIXME(chengduo): Doesn't add doc for 'use_cuda', use_cuda may + // make user confuse, because ParallelExecutor has a parameter named + // 'use_cuda' too, in current implementation, ParallelExecutor's + // 'use_cuda' will rewrite ExecutionStrategy's 'use_cuda'. .def_property( "allow_op_delay", [](const ExecutionStrategy &self) { return self.allow_op_delay_; }, [](ExecutionStrategy &self, bool allow_op_delay) { self.allow_op_delay_ = allow_op_delay; - }) + }, + R"DOC(The type is BOOL, allow_op_delay represents whether to delay the + communication operators to run, it may make the execution faster. + Note that in some models, allow_op_delay may cause program hang. Default False.)DOC") .def_property( "num_iteration_per_drop_scope", [](const ExecutionStrategy &self) { @@ -706,7 +722,19 @@ All parameter, weight, gradient are variables in Paddle. }, [](ExecutionStrategy &self, size_t num_iteration_per_drop_scope) { self.num_iteration_per_drop_scope_ = num_iteration_per_drop_scope; - }); + }, + R"DOC(The type is INT, num_iteration_per_drop_scope indicates how + many iterations to clean up the temp variables which + is generated during execution. It may make the execution faster, + because the temp variable's shape maybe the same between two iterations. Default 100. + + NOTES: + 1. If you fetch data when calling the 'run', the ParallelExecutor + will clean up the temp variables at the end of the current iteration. + 2. In some NLP model, it may cause the GPU memory is insufficient, + in this case, you should reduce `num_iteration_per_drop_scope`. + )DOC"); + exec_strategy.def_property( "use_experimental_executor", [](const ExecutionStrategy &self) { @@ -721,20 +749,17 @@ All parameter, weight, gradient are variables in Paddle. BuildStrategy allows the user to more preciously control how to build the SSA Graph in ParallelExecutor by setting the property. - The available properties include: - reduce_strategy (str): There are two reduce strategies, 'AllReduce' - and 'Reduce'. If you want that all parameters will be optimized - on all devices, you can choose 'AllReduce'; if you choose - 'Reduce', all parameters will be evenly allocated to different - devices for optimization, and then broadcast the optimized - parameter to other devices. Default 'AllReduce'. - gradient_scale_strategy (str): There are two ways of defining loss@grad, - 'CoeffNumDevice' and 'Customized'. By default, ParallelExecutor - sets the loss@grad according to the number of devices. If you want - to customize loss@grad, you can choose 'Customized'. - Default 'CoeffNumDevice'. - debug_graphviz_path (str): Whether to write the SSA Graph to file in the - form of graphviz. It is useful for debugging. Default "". + Examples: + .. code-block:: python + + build_strategy = fluid.BuildStrategy() + build_strategy.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.Reduce + + train_exe = fluid.ParallelExecutor(use_cuda=True, + loss_name=loss.name, + build_strategy=build_strategy) + + train_loss, = train_exe.run([loss.name], feed=feed_dict) )DOC"); py::enum_(build_strategy, "ReduceStrategy") @@ -753,31 +778,51 @@ All parameter, weight, gradient are variables in Paddle. [](const BuildStrategy &self) { return self.reduce_; }, [](BuildStrategy &self, BuildStrategy::ReduceStrategy strategy) { self.reduce_ = strategy; - }) + }, + R"DOC(The type is STR, there are two reduce strategies in ParallelExecutor, + 'AllReduce' and 'Reduce'. If you want that all the parameters' + optimization are done on all devices independently, you should choose 'AllReduce'; + if you choose 'Reduce', all the parameters' optimization will be evenly distributed + to different devices, and then broadcast the optimized parameter to other devices. + In some models, `Reduce` is faster. Default 'AllReduce'. )DOC") .def_property( "gradient_scale_strategy", [](const BuildStrategy &self) { return self.gradient_scale_; }, [](BuildStrategy &self, BuildStrategy::GradientScaleStrategy strategy) { self.gradient_scale_ = strategy; - }) + }, + R"DOC(The type is STR, there are three ways of defining :math:`loss@grad` in + ParallelExecutor, 'CoeffNumDevice', 'One' and 'Customized'. By default, + ParallelExecutor sets the :math:`loss@grad` according to the number of devices. + If you want to customize :math:`loss@grad`, you can choose 'Customized'. + Default 'CoeffNumDevice'.)DOC") .def_property( "debug_graphviz_path", [](const BuildStrategy &self) { return self.debug_graphviz_path_; }, [](BuildStrategy &self, const std::string &path) { self.debug_graphviz_path_ = path; - }) + }, + R"DOC(The type is STR, debug_graphviz_path indicate the path that + writing the SSA Graph to file in the form of graphviz, you. + It is useful for debugging. Default "")DOC") .def_property( "enable_data_balance", [](const BuildStrategy &self) { return self.enable_data_balance_; }, - [](BuildStrategy &self, bool b) { self.enable_data_balance_ = b; }) - .def_property("fuse_elewise_add_act_ops", - [](const BuildStrategy &self) { - return self.fuse_elewise_add_act_ops_; - }, - [](BuildStrategy &self, bool b) { - self.fuse_elewise_add_act_ops_ = b; - }) + [](BuildStrategy &self, bool b) { + self.enable_data_balance_ = b; + }) // FIXME(chengudo): enable_data_balance seems not important + .def_property( + "fuse_elewise_add_act_ops", + [](const BuildStrategy &self) { + return self.fuse_elewise_add_act_ops_; + }, + [](BuildStrategy &self, bool b) { + self.fuse_elewise_add_act_ops_ = b; + }, + R"DOC(The type is BOOL, fuse_elewise_add_act_ops indicate whether + to fuse elementwise_add_op and activation_op, + it may make the execution faster. Default False)DOC") .def("_create_passes_from_strategy", [](BuildStrategy &self) -> std::shared_ptr { return self.CreatePassesFromStrategy(); diff --git a/python/paddle/fluid/parallel_executor.py b/python/paddle/fluid/parallel_executor.py index 57d272cbfb948840679e80e8db40379c57603113..3f4dd5eb712e738bbee8f93c062375033b8ab2f6 100644 --- a/python/paddle/fluid/parallel_executor.py +++ b/python/paddle/fluid/parallel_executor.py @@ -31,15 +31,32 @@ BuildStrategy = core.ParallelExecutor.BuildStrategy class ParallelExecutor(object): """ - ParallelExecutor can run program in parallel. + ParallelExecutor is designed for data parallelism, which focuses on distributing + the data across different nodes and every node operates on the data in parallel. + If you use ParallelExecutor to run the current program on GPU, the node means GPU + device, and ParallelExecutor will get the available GPU device automatically on + the current machine. If you use ParallelExecutor to run the current program on CPU, + the node means the CPU device, and you can specify the CPU device number by adding + 'CPU_NUM' environment variable, for example 'CPU_NUM=4', if the environment variable + is not found, ParallelExecutor will call `multiprocessing.cpu_count` to get the number + of CPUs in the system. Args: use_cuda (bool): Whether to use CUDA or not. loss_name (str): The loss name must set in training. Default None. main_program (Program): The program that need to run, if not provided, then default_main_program will be used. Default None. - share_vars_from(ParallelExecutor): If provied, it will share variables + share_vars_from(ParallelExecutor): If provide, it will share variables from the specified ParallelExecutor. Default None. + exec_strategy(ExecutionStrategy): exec_strategy is used to control how to run + the program in ParallelExecutor, for example how many threads are used to + execute the program, how many iterations to clean up the temp variables + which is generated during execution. For more information, please refer + to fluid.ExecutionStrategy. Default None. + build_strategy(BuildStrategy): build_strategy is used to control how to + build the SSA Graph in ParallelExecutor by setting the property, + for example reduce_strategy, gradient_scale_strategy. For more information, + please refer to fluid.BuildStrategy. Default None. num_trainers(int): If greater than 1, NCCL will be initialized with multiple rank of nodes, each node should have same number of GPUs. Distributed training will be enabled then. Default 1.