提交 eb6d9e3b 编写于 作者: Q Qiao Longfei

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into optimize-sum-seq-pooling-op

......@@ -70,7 +70,7 @@ class DfgPassManagerImpl final : public DfgPassManager {
auto trt_teller = [&](const Node* node) {
std::unordered_set<std::string> teller_set(
{"mul", "conv2d", "pool2d", "relu", "softmax", "sigmoid",
"depthwise_conv2d", "batch_norm", "concat", "tanh",
"depthwise_conv2d", "batch_norm", "concat", "tanh", "pad",
"elementwise_add", "dropout"});
if (!node->IsFunction()) return false;
......
......@@ -185,3 +185,4 @@ USE_TRT_CONVERTER(softmax);
USE_TRT_CONVERTER(batch_norm);
USE_TRT_CONVERTER(concat);
USE_TRT_CONVERTER(dropout);
USE_TRT_CONVERTER(pad);
# Add TRT tests
nv_library(tensorrt_converter
SRCS mul_op.cc conv2d_op.cc fc_op.cc pool2d_op.cc elementwise_op.cc
batch_norm_op.cc activation_op.cc softmax_op.cc concat_op.cc dropout_op.cc
batch_norm_op.cc activation_op.cc softmax_op.cc concat_op.cc dropout_op.cc pad_op.cc
DEPS tensorrt_engine operator scope framework_proto op_registry)
nv_test(test_op_converter SRCS test_op_converter.cc DEPS
......@@ -26,6 +26,8 @@ nv_test(test_trt_batch_norm_op SRCS test_batch_norm_op.cc batch_norm_op.cc
DEPS ${FLUID_CORE_MODULES} tensorrt_engine batch_norm_op SERIAL)
nv_test(test_trt_concat_op SRCS test_concat_op.cc concat_op.cc
DEPS ${FLUID_CORE_MODULES} tensorrt_engine concat_op SERIAL)
nv_test(test_trt_dropout_op SRCS test_dropout_op.cc dropout_op.cc
DEPS ${FLUID_CORE_MODULES} tensorrt_engine dropout_op SERIAL)
nv_test(test_trt_pad_op SRCS test_pad_op.cc pad_op.cc
DEPS ${FLUID_CORE_MODULES} tensorrt_engine pad_op SERIAL)
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
namespace paddle {
namespace inference {
namespace tensorrt {
/*
* PadOp.
*/
class PadOpConverter : public OpConverter {
public:
void operator()(const framework::proto::OpDesc& op,
const framework::Scope& scope, bool test_mode) override {
VLOG(4) << "convert a fluid transpose op to tensorrt tranpose layer";
framework::OpDesc op_desc(op, nullptr);
// Declare inputs
auto* input = engine_->GetITensor(op_desc.Input("X")[0]);
const std::vector<int> paddings =
boost::get<std::vector<int>>(op_desc.GetAttr("paddings"));
const float pad_value = boost::get<float>(op_desc.GetAttr("pad_value"));
nvinfer1::Dims input_shape = input->getDimensions();
int nbDims = input_shape.nbDims;
int pad_size = static_cast<int>(paddings.size());
PADDLE_ENFORCE_GE(nbDims, 2);
PADDLE_ENFORCE_EQ((nbDims + 1) * 2, pad_size);
PADDLE_ENFORCE(pad_value == 0.0, "The pad layer of TRT only support zero.");
nvinfer1::DimsHW pre_pad(paddings[pad_size - 4], paddings[pad_size - 2]);
nvinfer1::DimsHW post_pad(paddings[pad_size - 3], paddings[pad_size - 1]);
auto* layer = TRT_ENGINE_ADD_LAYER(engine_, Padding,
*const_cast<nvinfer1::ITensor*>(input),
pre_pad, post_pad);
PADDLE_ENFORCE(layer != nullptr);
auto output_name = op_desc.Output("Out")[0];
engine_->SetITensor(output_name, layer->getOutput(0));
layer->setName(("scale (Output: " + output_name + ")").c_str());
layer->getOutput(0)->setName(output_name.c_str());
if (test_mode) { // the test framework can not determine which is the
// output, so place the declaration inside.
engine_->DeclareOutput(output_name);
}
}
};
} // namespace tensorrt
} // namespace inference
} // namespace paddle
REGISTER_TRT_OP_CONVERTER(pad, PadOpConverter);
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <gtest/gtest.h>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/inference/tensorrt/convert/ut_helper.h"
namespace paddle {
namespace inference {
namespace tensorrt {
TEST(PadConverter, main) {
framework::Scope scope;
std::unordered_set<std::string> parameters;
TRTConvertValidation validator(10, parameters, scope, 1000);
validator.DeclInputVar("pad-X", nvinfer1::Dims3(3, 2, 2));
validator.DeclOutputVar("pad-Out", nvinfer1::Dims3(3, 3, 5));
// Prepare Op description
framework::OpDesc desc;
desc.SetType("pad");
desc.SetInput("X", {"pad-X"});
desc.SetOutput("Out", {"pad-Out"});
std::vector<int> paddings = {0, 0, 0, 0, 0, 1, 1, 2};
float pad_value = 0.0;
desc.SetAttr("paddings", paddings);
desc.SetAttr("pad_value", pad_value);
LOG(INFO) << "set OP";
validator.SetOp(*desc.Proto());
LOG(INFO) << "execute";
validator.Execute(2);
}
} // namespace tensorrt
} // namespace inference
} // namespace paddle
USE_OP(pad);
......@@ -3,8 +3,8 @@ add_subdirectory(detail)
endif(NOT WIN32)
function(math_library TARGET)
# math_library is a function to create math library.
# The interface is the same as cc_library.
# math_library is a function to create math library.
# The interface is the same as cc_library.
# But it handle split GPU/CPU code and link some common library.
set(cc_srcs)
set(cu_srcs)
......@@ -53,7 +53,7 @@ cc_library(blas SRCS blas.cc DEPS cblas framework_proto device_context)
math_library(math_function DEPS blas)
math_library(maxouting)
math_library(pooling)
math_library(selected_rows_functor DEPS selected_rows math_function)
math_library(selected_rows_functor DEPS selected_rows math_function blas)
math_library(sequence2batch)
math_library(sequence_padding)
math_library(sequence_pooling DEPS math_function)
......
......@@ -15,7 +15,6 @@ limitations under the License. */
#include <set>
#include <unordered_map>
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/selected_rows_functor.h"
namespace paddle {
......@@ -150,6 +149,45 @@ template struct SelectedRowsAddTo<platform::CPUDeviceContext, double>;
template struct SelectedRowsAddTo<platform::CPUDeviceContext, int>;
template struct SelectedRowsAddTo<platform::CPUDeviceContext, int64_t>;
template <typename T>
struct SelectedRowsSumTo<platform::CPUDeviceContext, T> {
void operator()(const platform::CPUDeviceContext& context,
const std::vector<framework::SelectedRows*>& input1,
const std::vector<int64_t>& input2_offsets,
framework::SelectedRows* input2) {
// Ensure all selected rows have the same height
size_t size = 0u;
for (auto iter = input1.begin(); iter != input1.end(); ++iter) {
auto& in_rows = (*iter)->rows();
size += in_rows.end() - in_rows.begin();
auto in1_height = (*iter)->height();
PADDLE_ENFORCE_EQ(in1_height, input2->height());
}
// concat rows
std::vector<int64_t> in2_rows;
in2_rows.reserve(in2_rows.size() + size);
for (auto iter = input1.begin(); iter != input1.end(); ++iter) {
const framework::Vector<int64_t>& in_rows = (*iter)->rows();
in2_rows.insert(in2_rows.end(), in_rows.begin(), in_rows.end());
}
input2->set_rows(in2_rows);
auto* in2_value = input2->mutable_value();
auto* in2_data = in2_value->data<T>();
auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context);
size_t offset = 0u;
for (size_t i = 0u; i != input1.size(); ++i) {
auto& in_value = input1[i]->value();
const auto* in_data = in_value.data<T>();
offset += input2_offsets[i];
blas.VCOPY(in_value.numel(), in_data, in2_data + offset);
}
}
};
template struct SelectedRowsSumTo<platform::CPUDeviceContext, float>;
template struct SelectedRowsSumTo<platform::CPUDeviceContext, double>;
template <typename T>
struct SelectedRowsAddToTensor<platform::CPUDeviceContext, T> {
void operator()(const platform::CPUDeviceContext& context,
......@@ -260,8 +298,6 @@ struct MergeAdd<platform::CPUDeviceContext, T> {
}
};
template struct MergeAdd<platform::CPUDeviceContext, float>;
template struct MergeAdd<platform::CPUDeviceContext, double>;
template struct MergeAdd<platform::CPUDeviceContext, int>;
template struct MergeAdd<platform::CPUDeviceContext, int64_t>;
......
......@@ -12,10 +12,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <vector>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/device_context.h"
#define INLINE_FOR2(sizei, sizej) \
......@@ -51,6 +54,15 @@ struct SelectedRowsAddTo {
const int64_t input2_offset, framework::SelectedRows* input2);
};
// input2 = [all input in input1] + input2
template <typename DeviceContext, typename T>
struct SelectedRowsSumTo {
void operator()(const DeviceContext& context,
const std::vector<framework::SelectedRows*>& input1,
const std::vector<int64_t>& input2_offsets,
framework::SelectedRows* input2);
};
// input2 = input1 + input2
template <typename DeviceContext, typename T>
struct SelectedRowsAddToTensor {
......@@ -75,6 +87,108 @@ struct MergeAdd {
framework::SelectedRows* output);
};
template <>
struct MergeAdd<platform::CPUDeviceContext, float> {
framework::SelectedRows operator()(const platform::CPUDeviceContext& context,
const framework::SelectedRows& input) {
framework::SelectedRows out;
(*this)(context, input, &out);
return out;
}
void operator()(const platform::CPUDeviceContext& context,
const framework::SelectedRows& input,
framework::SelectedRows* output) {
framework::SelectedRows& out = *output;
auto input_rows = input.rows();
std::vector<int64_t> merge_rows;
merge_rows.reserve(input_rows.size());
std::unordered_map<int64_t, size_t> rows_pos_map;
rows_pos_map.reserve(input_rows.size());
size_t idx = 0u;
for (std::vector<int64_t>::iterator iter = input_rows.begin();
iter != input_rows.end(); ++iter) {
if (rows_pos_map.find(*iter) == rows_pos_map.end()) {
rows_pos_map[*iter] = idx++;
merge_rows.emplace_back(*iter);
}
}
auto input_width = input.value().dims()[1];
out.set_rows(merge_rows);
out.set_height(input.height());
out.mutable_value()->mutable_data<float>(
framework::make_ddim(
{static_cast<int64_t>(merge_rows.size()), input_width}),
context.GetPlace());
math::SetConstant<platform::CPUDeviceContext, float> constant_functor;
constant_functor(context, out.mutable_value(), 0.0);
auto* out_data = out.mutable_value()->data<float>();
auto* input_data = input.value().data<float>();
auto blas = GetBlas<platform::CPUDeviceContext, float>(context);
for (size_t i = 0; i < input_rows.size(); i++) {
size_t out_i = rows_pos_map[input_rows[i]];
float* y = out_data + out_i * input_width;
const float* x = input_data + i * input_width;
blas.AXPY(input_width, 1., x, y);
}
}
};
template <>
struct MergeAdd<platform::CPUDeviceContext, double> {
framework::SelectedRows operator()(const platform::CPUDeviceContext& context,
const framework::SelectedRows& input) {
framework::SelectedRows out;
(*this)(context, input, &out);
return out;
}
void operator()(const platform::CPUDeviceContext& context,
const framework::SelectedRows& input,
framework::SelectedRows* output) {
framework::SelectedRows& out = *output;
auto input_rows = input.rows();
std::vector<int64_t> merge_rows;
merge_rows.reserve(input_rows.size());
std::unordered_map<int64_t, size_t> rows_pos_map;
rows_pos_map.reserve(input_rows.size());
size_t idx = 0u;
for (std::vector<int64_t>::iterator iter = input_rows.begin();
iter != input_rows.end(); ++iter) {
if (rows_pos_map.find(*iter) == rows_pos_map.end()) {
rows_pos_map[*iter] = idx++;
merge_rows.emplace_back(*iter);
}
}
auto input_width = input.value().dims()[1];
out.set_rows(merge_rows);
out.set_height(input.height());
out.mutable_value()->mutable_data<double>(
framework::make_ddim(
{static_cast<int64_t>(merge_rows.size()), input_width}),
context.GetPlace());
math::SetConstant<platform::CPUDeviceContext, double> constant_functor;
constant_functor(context, out.mutable_value(), 0.0);
auto* out_data = out.mutable_value()->data<double>();
auto* input_data = input.value().data<double>();
auto blas = GetBlas<platform::CPUDeviceContext, double>(context);
for (size_t i = 0; i < input_rows.size(); i++) {
size_t out_i = rows_pos_map[input_rows[i]];
double* y = out_data + out_i * input_width;
const double* x = input_data + i * input_width;
blas.AXPY(input_width, 1., x, y);
}
}
};
template <typename DeviceContext, typename T>
struct Add {
framework::SelectedRows operator()(const DeviceContext& context,
......
......@@ -220,17 +220,98 @@ TEST(selected_rows_functor, cpu_add_to) {
EXPECT_EQ(tensor1_data[9 * row_numel + 6], 5.0);
}
TEST(selected_rows_functor, cpu_merge_add) {
TEST(selected_rows_functor, cpu_merge_add_float) {
paddle::platform::CPUPlace cpu_place;
paddle::platform::CPUDeviceContext ctx(cpu_place);
paddle::operators::math::SetConstant<paddle::platform::CPUDeviceContext,
float>
set_const;
functor;
int64_t height = 10;
int64_t row_numel = 10;
std::vector<int64_t> rows{0, 4, 4, 7};
std::unique_ptr<paddle::framework::SelectedRows> selected_rows{
new paddle::framework::SelectedRows(rows, height)};
auto* in_value = selected_rows->mutable_value();
in_value->mutable_data<float>(
paddle::framework::make_ddim(
{static_cast<int64_t>(rows.size()), row_numel}),
cpu_place);
functor(ctx, in_value, 1.0);
std::unique_ptr<paddle::framework::SelectedRows> output{
new paddle::framework::SelectedRows()};
paddle::operators::math::scatter::MergeAdd<paddle::platform::CPUDeviceContext,
float>
merge_add_functor;
merge_add_functor(ctx, *selected_rows, output.get());
auto out_height = output->height();
EXPECT_EQ(out_height, height);
auto& out_rows = output->rows();
EXPECT_EQ(out_rows[0], 0);
EXPECT_EQ(out_rows[1], 4);
EXPECT_EQ(out_rows[2], 7);
auto* out_data = output->value().data<float>();
EXPECT_EQ(out_data[0 * row_numel], 1.0);
EXPECT_EQ(out_data[1 * row_numel], 2.0);
EXPECT_EQ(out_data[2 * row_numel], 1.0);
}
TEST(selected_rows_functor, cpu_merge_add_int) {
paddle::platform::CPUPlace cpu_place;
paddle::platform::CPUDeviceContext ctx(cpu_place);
paddle::operators::math::SetConstant<paddle::platform::CPUDeviceContext, int>
functor;
int64_t height = 10;
int64_t row_numel = 8;
int64_t row_numel = 10;
std::vector<int64_t> rows1{5, 2, 5, 3, 5};
std::vector<int64_t> rows{0, 4, 4, 7};
std::unique_ptr<paddle::framework::SelectedRows> selected_rows{
new paddle::framework::SelectedRows(rows, height)};
auto* in_value = selected_rows->mutable_value();
in_value->mutable_data<int>(
paddle::framework::make_ddim(
{static_cast<int64_t>(rows.size()), row_numel}),
cpu_place);
functor(ctx, in_value, 1);
std::unique_ptr<paddle::framework::SelectedRows> output{
new paddle::framework::SelectedRows()};
paddle::operators::math::scatter::MergeAdd<paddle::platform::CPUDeviceContext,
int>
merge_add_functor;
merge_add_functor(ctx, *selected_rows, output.get());
auto out_height = output->height();
EXPECT_EQ(out_height, height);
auto& out_rows = output->rows();
EXPECT_EQ(out_rows[0], 0);
EXPECT_EQ(out_rows[1], 4);
EXPECT_EQ(out_rows[2], 7);
auto* out_data = output->value().data<int>();
EXPECT_EQ(out_data[0 * row_numel], 1);
EXPECT_EQ(out_data[1 * row_numel], 2);
EXPECT_EQ(out_data[2 * row_numel], 1);
}
TEST(selected_rows_functor, cpu_sum_to) {
paddle::platform::CPUPlace cpu_place;
paddle::platform::CPUDeviceContext ctx(cpu_place);
paddle::operators::math::SetConstant<paddle::platform::CPUDeviceContext,
float>
functor;
int64_t height = 10;
int64_t row_numel = 10;
std::vector<int64_t> rows1{0, 4, 7};
std::unique_ptr<paddle::framework::SelectedRows> selected_rows1{
new paddle::framework::SelectedRows(rows1, height)};
auto* in1_value = selected_rows1->mutable_value();
......@@ -238,9 +319,9 @@ TEST(selected_rows_functor, cpu_merge_add) {
paddle::framework::make_ddim(
{static_cast<int64_t>(rows1.size()), row_numel}),
cpu_place);
set_const(ctx, in1_value, 1.0);
std::vector<int64_t> rows2{2, 5, 3, 5, 3};
functor(ctx, in1_value, 1.0);
std::vector<int64_t> rows2{0, 5, 7, 9};
std::unique_ptr<paddle::framework::SelectedRows> selected_rows2{
new paddle::framework::SelectedRows(rows2, height)};
auto* in2_value = selected_rows2->mutable_value();
......@@ -248,33 +329,67 @@ TEST(selected_rows_functor, cpu_merge_add) {
paddle::framework::make_ddim(
{static_cast<int64_t>(rows2.size()), row_numel}),
cpu_place);
set_const(ctx, in2_value, 1.0);
functor(ctx, in2_value, 2.0);
std::unique_ptr<paddle::framework::SelectedRows> output{
new paddle::framework::SelectedRows()};
output->set_height(height);
paddle::operators::math::scatter::MergeAdd<paddle::platform::CPUDeviceContext,
auto* out_value = output->mutable_value();
// simplely concat two SelectedRows
out_value->mutable_data<float>(paddle::framework::make_ddim({7, 10}),
cpu_place);
paddle::operators::math::SelectedRowsSumTo<paddle::platform::CPUDeviceContext,
float>
merge_add_functor;
std::vector<const paddle::framework::SelectedRows*> inputs;
inputs.push_back(selected_rows1.get());
inputs.push_back(selected_rows2.get());
merge_add_functor(ctx, inputs, output.get());
EXPECT_EQ(output->height(), height);
EXPECT_EQ(output->value().dims(),
paddle::framework::make_ddim({3, row_numel}));
std::vector<int64_t> ret_rows{2, 3, 5};
EXPECT_EQ(output->rows(), ret_rows);
sum_to_functor;
sum_to_functor(ctx, std::vector<paddle::framework::SelectedRows*>(
{selected_rows1.get(), selected_rows2.get()}),
std::vector<int64_t>({0, in1_value->numel()}), output.get());
auto out_height = output->height();
EXPECT_EQ(out_height, height);
auto& out_rows = output->rows();
// input1 rows
EXPECT_EQ(out_rows[0], 0);
EXPECT_EQ(out_rows[1], 4);
EXPECT_EQ(out_rows[2], 7);
// input2 rows
EXPECT_EQ(out_rows[3], 0);
EXPECT_EQ(out_rows[4], 5);
EXPECT_EQ(out_rows[5], 7);
EXPECT_EQ(out_rows[6], 9);
auto* out_data = output->value().data<float>();
for (size_t i = 0; i < ret_rows.size(); ++i) {
for (size_t j = 0; j < row_numel; ++j) {
EXPECT_EQ(out_data[i * row_numel + j], ret_rows[i]);
std::cout << out_data[i * row_numel + j] << " ";
}
std::cout << "\n";
}
// input1 value
EXPECT_EQ(out_data[0 * row_numel + 0], 1.0);
EXPECT_EQ(out_data[0 * row_numel + 8], 1.0);
EXPECT_EQ(out_data[1 * row_numel + 1], 1.0);
EXPECT_EQ(out_data[2 * row_numel + 6], 1.0);
// input2 value
EXPECT_EQ(out_data[3 * row_numel + 3], 2.0);
EXPECT_EQ(out_data[3 * row_numel + 8], 2.0);
EXPECT_EQ(out_data[4 * row_numel + 4], 2.0);
EXPECT_EQ(out_data[5 * row_numel + 7], 2.0);
EXPECT_EQ(out_data[6 * row_numel + 9], 2.0);
std::unique_ptr<paddle::framework::Tensor> tensor1{
new paddle::framework::Tensor()};
tensor1->mutable_data<float>(
paddle::framework::make_ddim({height, row_numel}), cpu_place);
functor(ctx, tensor1.get(), 3.0);
paddle::operators::math::SelectedRowsAddToTensor<
paddle::platform::CPUDeviceContext, float>
add_to_tensor_functor;
add_to_tensor_functor(ctx, *output, tensor1.get());
auto* tensor1_data = tensor1->data<float>();
// row0: 1.0 + 2.0 + 3.0
EXPECT_EQ(tensor1_data[0 * row_numel + 0], 6.0);
// row1: 3.0
EXPECT_EQ(tensor1_data[1 * row_numel + 1], 3.0);
// row4 : 1.0 + 3.0
EXPECT_EQ(tensor1_data[4 * row_numel + 6], 4.0);
// row5: 2.0 + 3.0
EXPECT_EQ(tensor1_data[5 * row_numel + 7], 5.0);
// row6: 3.0
EXPECT_EQ(tensor1_data[6 * row_numel + 1], 3.0);
// row7: 1.0 + 2.0 + 3.0
EXPECT_EQ(tensor1_data[7 * row_numel + 3], 6.0);
// row9: 2.0 + 3.0
EXPECT_EQ(tensor1_data[9 * row_numel + 6], 5.0);
}
......@@ -667,16 +667,17 @@ All parameter, weight, gradient are variables in Paddle.
ExecutionStrategy allows the user to more preciously control how to run
the program in ParallelExecutor by setting the property.
The available properties include:
use_cuda (bool): Whether to use CUDA or not. Default True.
num_threads (int): The number of threads that used to run the
operators in ParallelExecutor. If it is not set, it will be
set in ParallelExecutor according to the device count.
Default 0.
allow_op_delay (bool): Whether to delay the communication operators
to run. Default False.
num_iteration_per_drop_scope (int): how many iterations between
the two dropping local scopes. Default 100.
Examples:
.. code-block:: python
exec_strategy = fluid.ExecutionStrategy()
exec_strategy.num_threads = 4
train_exe = fluid.ParallelExecutor(use_cuda=True,
loss_name=loss.name,
exec_strategy=exec_strategy)
train_loss, = train_exe.run([loss.name], feed=feed_dict)
)DOC");
......@@ -686,19 +687,34 @@ All parameter, weight, gradient are variables in Paddle.
[](const ExecutionStrategy &self) { return self.num_threads_; },
[](ExecutionStrategy &self, size_t num_threads) {
self.num_threads_ = num_threads;
})
},
R"DOC(The type is INT, num_threads represents the size of thread pool that
used to run the operators of the current program in ParallelExecutor.
If :math:`num\_threads=1`, all the operators will execute one by one,
but the order maybe difference between iterations.
If it is not set, it will be set in ParallelExecutor according to the
device type and device count, for GPU, :math:`num\_threads=device\_count*4`, for CPU,
:math:`num\_threads=CPU\_NUM*4`, the explanation of:math:`CPU\_NUM` is in ParallelExecutor.
if it is not set, ParallelExecutor will get the cpu count by calling
`multiprocessing.cpu_count()`. Default 0.)DOC")
.def_property(
"use_cuda",
[](const ExecutionStrategy &self) { return self.use_cuda_; },
[](ExecutionStrategy &self, bool use_cuda) {
self.use_cuda_ = use_cuda;
})
}) // FIXME(chengduo): Doesn't add doc for 'use_cuda', use_cuda may
// make user confuse, because ParallelExecutor has a parameter named
// 'use_cuda' too, in current implementation, ParallelExecutor's
// 'use_cuda' will rewrite ExecutionStrategy's 'use_cuda'.
.def_property(
"allow_op_delay",
[](const ExecutionStrategy &self) { return self.allow_op_delay_; },
[](ExecutionStrategy &self, bool allow_op_delay) {
self.allow_op_delay_ = allow_op_delay;
})
},
R"DOC(The type is BOOL, allow_op_delay represents whether to delay the
communication operators to run, it may make the execution faster.
Note that in some models, allow_op_delay may cause program hang. Default False.)DOC")
.def_property(
"num_iteration_per_drop_scope",
[](const ExecutionStrategy &self) {
......@@ -706,7 +722,19 @@ All parameter, weight, gradient are variables in Paddle.
},
[](ExecutionStrategy &self, size_t num_iteration_per_drop_scope) {
self.num_iteration_per_drop_scope_ = num_iteration_per_drop_scope;
});
},
R"DOC(The type is INT, num_iteration_per_drop_scope indicates how
many iterations to clean up the temp variables which
is generated during execution. It may make the execution faster,
because the temp variable's shape maybe the same between two iterations. Default 100.
NOTES:
1. If you fetch data when calling the 'run', the ParallelExecutor
will clean up the temp variables at the end of the current iteration.
2. In some NLP model, it may cause the GPU memory is insufficient,
in this case, you should reduce `num_iteration_per_drop_scope`.
)DOC");
exec_strategy.def_property(
"use_experimental_executor",
[](const ExecutionStrategy &self) {
......@@ -721,20 +749,17 @@ All parameter, weight, gradient are variables in Paddle.
BuildStrategy allows the user to more preciously control how to
build the SSA Graph in ParallelExecutor by setting the property.
The available properties include:
reduce_strategy (str): There are two reduce strategies, 'AllReduce'
and 'Reduce'. If you want that all parameters will be optimized
on all devices, you can choose 'AllReduce'; if you choose
'Reduce', all parameters will be evenly allocated to different
devices for optimization, and then broadcast the optimized
parameter to other devices. Default 'AllReduce'.
gradient_scale_strategy (str): There are two ways of defining loss@grad,
'CoeffNumDevice' and 'Customized'. By default, ParallelExecutor
sets the loss@grad according to the number of devices. If you want
to customize loss@grad, you can choose 'Customized'.
Default 'CoeffNumDevice'.
debug_graphviz_path (str): Whether to write the SSA Graph to file in the
form of graphviz. It is useful for debugging. Default "".
Examples:
.. code-block:: python
build_strategy = fluid.BuildStrategy()
build_strategy.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.Reduce
train_exe = fluid.ParallelExecutor(use_cuda=True,
loss_name=loss.name,
build_strategy=build_strategy)
train_loss, = train_exe.run([loss.name], feed=feed_dict)
)DOC");
py::enum_<BuildStrategy::ReduceStrategy>(build_strategy, "ReduceStrategy")
......@@ -753,31 +778,51 @@ All parameter, weight, gradient are variables in Paddle.
[](const BuildStrategy &self) { return self.reduce_; },
[](BuildStrategy &self, BuildStrategy::ReduceStrategy strategy) {
self.reduce_ = strategy;
})
},
R"DOC(The type is STR, there are two reduce strategies in ParallelExecutor,
'AllReduce' and 'Reduce'. If you want that all the parameters'
optimization are done on all devices independently, you should choose 'AllReduce';
if you choose 'Reduce', all the parameters' optimization will be evenly distributed
to different devices, and then broadcast the optimized parameter to other devices.
In some models, `Reduce` is faster. Default 'AllReduce'. )DOC")
.def_property(
"gradient_scale_strategy",
[](const BuildStrategy &self) { return self.gradient_scale_; },
[](BuildStrategy &self,
BuildStrategy::GradientScaleStrategy strategy) {
self.gradient_scale_ = strategy;
})
},
R"DOC(The type is STR, there are three ways of defining :math:`loss@grad` in
ParallelExecutor, 'CoeffNumDevice', 'One' and 'Customized'. By default,
ParallelExecutor sets the :math:`loss@grad` according to the number of devices.
If you want to customize :math:`loss@grad`, you can choose 'Customized'.
Default 'CoeffNumDevice'.)DOC")
.def_property(
"debug_graphviz_path",
[](const BuildStrategy &self) { return self.debug_graphviz_path_; },
[](BuildStrategy &self, const std::string &path) {
self.debug_graphviz_path_ = path;
})
},
R"DOC(The type is STR, debug_graphviz_path indicate the path that
writing the SSA Graph to file in the form of graphviz, you.
It is useful for debugging. Default "")DOC")
.def_property(
"enable_data_balance",
[](const BuildStrategy &self) { return self.enable_data_balance_; },
[](BuildStrategy &self, bool b) { self.enable_data_balance_ = b; })
.def_property("fuse_elewise_add_act_ops",
[](const BuildStrategy &self) {
return self.fuse_elewise_add_act_ops_;
},
[](BuildStrategy &self, bool b) {
self.fuse_elewise_add_act_ops_ = b;
})
[](BuildStrategy &self, bool b) {
self.enable_data_balance_ = b;
}) // FIXME(chengudo): enable_data_balance seems not important
.def_property(
"fuse_elewise_add_act_ops",
[](const BuildStrategy &self) {
return self.fuse_elewise_add_act_ops_;
},
[](BuildStrategy &self, bool b) {
self.fuse_elewise_add_act_ops_ = b;
},
R"DOC(The type is BOOL, fuse_elewise_add_act_ops indicate whether
to fuse elementwise_add_op and activation_op,
it may make the execution faster. Default False)DOC")
.def("_create_passes_from_strategy",
[](BuildStrategy &self) -> std::shared_ptr<ir::PassBuilder> {
return self.CreatePassesFromStrategy();
......
......@@ -31,15 +31,32 @@ BuildStrategy = core.ParallelExecutor.BuildStrategy
class ParallelExecutor(object):
"""
ParallelExecutor can run program in parallel.
ParallelExecutor is designed for data parallelism, which focuses on distributing
the data across different nodes and every node operates on the data in parallel.
If you use ParallelExecutor to run the current program on GPU, the node means GPU
device, and ParallelExecutor will get the available GPU device automatically on
the current machine. If you use ParallelExecutor to run the current program on CPU,
the node means the CPU device, and you can specify the CPU device number by adding
'CPU_NUM' environment variable, for example 'CPU_NUM=4', if the environment variable
is not found, ParallelExecutor will call `multiprocessing.cpu_count` to get the number
of CPUs in the system.
Args:
use_cuda (bool): Whether to use CUDA or not.
loss_name (str): The loss name must set in training. Default None.
main_program (Program): The program that need to run, if not provided,
then default_main_program will be used. Default None.
share_vars_from(ParallelExecutor): If provied, it will share variables
share_vars_from(ParallelExecutor): If provide, it will share variables
from the specified ParallelExecutor. Default None.
exec_strategy(ExecutionStrategy): exec_strategy is used to control how to run
the program in ParallelExecutor, for example how many threads are used to
execute the program, how many iterations to clean up the temp variables
which is generated during execution. For more information, please refer
to fluid.ExecutionStrategy. Default None.
build_strategy(BuildStrategy): build_strategy is used to control how to
build the SSA Graph in ParallelExecutor by setting the property,
for example reduce_strategy, gradient_scale_strategy. For more information,
please refer to fluid.BuildStrategy. Default None.
num_trainers(int): If greater than 1, NCCL will be initialized with
multiple rank of nodes, each node should have same number of GPUs.
Distributed training will be enabled then. Default 1.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册