提交 8580b7a1 编写于 作者: P peizhilin

Merge remote-tracking branch 'upstream/develop' into windows/build

...@@ -93,11 +93,11 @@ paddle.fluid.layers.edit_distance ArgSpec(args=['input', 'label', 'normalized', ...@@ -93,11 +93,11 @@ paddle.fluid.layers.edit_distance ArgSpec(args=['input', 'label', 'normalized',
paddle.fluid.layers.l2_normalize ArgSpec(args=['x', 'axis', 'epsilon', 'name'], varargs=None, keywords=None, defaults=(1e-12, None)) paddle.fluid.layers.l2_normalize ArgSpec(args=['x', 'axis', 'epsilon', 'name'], varargs=None, keywords=None, defaults=(1e-12, None))
paddle.fluid.layers.matmul ArgSpec(args=['x', 'y', 'transpose_x', 'transpose_y', 'alpha', 'name'], varargs=None, keywords=None, defaults=(False, False, 1.0, None)) paddle.fluid.layers.matmul ArgSpec(args=['x', 'y', 'transpose_x', 'transpose_y', 'alpha', 'name'], varargs=None, keywords=None, defaults=(False, False, 1.0, None))
paddle.fluid.layers.topk ArgSpec(args=['input', 'k', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.topk ArgSpec(args=['input', 'k', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.warpctc ArgSpec(args=['input', 'label', 'blank', 'norm_by_times'], varargs=None, keywords=None, defaults=(0, False)) paddle.fluid.layers.warpctc ArgSpec(args=['input', 'label', 'blank', 'norm_by_times', 'use_cudnn'], varargs=None, keywords=None, defaults=(0, False, False))
paddle.fluid.layers.sequence_reshape ArgSpec(args=['input', 'new_dim'], varargs=None, keywords=None, defaults=None) paddle.fluid.layers.sequence_reshape ArgSpec(args=['input', 'new_dim'], varargs=None, keywords=None, defaults=None)
paddle.fluid.layers.transpose ArgSpec(args=['x', 'perm', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.transpose ArgSpec(args=['x', 'perm', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.im2sequence ArgSpec(args=['input', 'filter_size', 'stride', 'padding', 'input_image_size', 'out_stride', 'name'], varargs=None, keywords=None, defaults=(1, 1, 0, None, 1, None)) paddle.fluid.layers.im2sequence ArgSpec(args=['input', 'filter_size', 'stride', 'padding', 'input_image_size', 'out_stride', 'name'], varargs=None, keywords=None, defaults=(1, 1, 0, None, 1, None))
paddle.fluid.layers.nce ArgSpec(args=['input', 'label', 'num_total_classes', 'sample_weight', 'param_attr', 'bias_attr', 'num_neg_samples', 'name'], varargs=None, keywords=None, defaults=(None, None, None, None, None)) paddle.fluid.layers.nce ArgSpec(args=['input', 'label', 'num_total_classes', 'sample_weight', 'param_attr', 'bias_attr', 'num_neg_samples', 'name', 'sampler', 'custom_dist', 'seed'], varargs=None, keywords=None, defaults=(None, None, None, None, None, 'uniform', None, 0))
paddle.fluid.layers.hsigmoid ArgSpec(args=['input', 'label', 'num_classes', 'param_attr', 'bias_attr', 'name'], varargs=None, keywords=None, defaults=(None, None, None)) paddle.fluid.layers.hsigmoid ArgSpec(args=['input', 'label', 'num_classes', 'param_attr', 'bias_attr', 'name'], varargs=None, keywords=None, defaults=(None, None, None))
paddle.fluid.layers.beam_search ArgSpec(args=['pre_ids', 'pre_scores', 'ids', 'scores', 'beam_size', 'end_id', 'level', 'name'], varargs=None, keywords=None, defaults=(0, None)) paddle.fluid.layers.beam_search ArgSpec(args=['pre_ids', 'pre_scores', 'ids', 'scores', 'beam_size', 'end_id', 'level', 'name'], varargs=None, keywords=None, defaults=(0, None))
paddle.fluid.layers.row_conv ArgSpec(args=['input', 'future_context_size', 'param_attr', 'act'], varargs=None, keywords=None, defaults=(None, None)) paddle.fluid.layers.row_conv ArgSpec(args=['input', 'future_context_size', 'param_attr', 'act'], varargs=None, keywords=None, defaults=(None, None))
......
...@@ -45,7 +45,7 @@ void IrAnalysisComposePass::InitTensorRTAttrs(Argument *argument) { ...@@ -45,7 +45,7 @@ void IrAnalysisComposePass::InitTensorRTAttrs(Argument *argument) {
std::unordered_set<std::string> teller_set( std::unordered_set<std::string> teller_set(
{"mul", "conv2d", "pool2d", "relu", "softmax", "sigmoid", {"mul", "conv2d", "pool2d", "relu", "softmax", "sigmoid",
"depthwise_conv2d", "batch_norm", "concat", "tanh", "pad", "depthwise_conv2d", "batch_norm", "concat", "tanh", "pad",
"elementwise_add", "dropout", "split"}); "elementwise_add", "dropout", "split", "prelu", "conv2d_transpose"});
if (!node->IsOp()) return false; if (!node->IsOp()) return false;
if (teller_set.count(node->Op()->Type())) { if (teller_set.count(node->Op()->Type())) {
......
...@@ -549,4 +549,6 @@ USE_TRT_CONVERTER(concat); ...@@ -549,4 +549,6 @@ USE_TRT_CONVERTER(concat);
USE_TRT_CONVERTER(dropout); USE_TRT_CONVERTER(dropout);
USE_TRT_CONVERTER(pad); USE_TRT_CONVERTER(pad);
USE_TRT_CONVERTER(split); USE_TRT_CONVERTER(split);
USE_TRT_CONVERTER(prelu);
USE_TRT_CONVERTER(conv2d_transpose);
#endif #endif
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
nv_library(tensorrt_converter nv_library(tensorrt_converter
SRCS mul_op.cc conv2d_op.cc fc_op.cc pool2d_op.cc elementwise_op.cc SRCS mul_op.cc conv2d_op.cc fc_op.cc pool2d_op.cc elementwise_op.cc
batch_norm_op.cc activation_op.cc softmax_op.cc concat_op.cc dropout_op.cc batch_norm_op.cc activation_op.cc softmax_op.cc concat_op.cc dropout_op.cc
pad_op.cc split_op.cc pad_op.cc split_op.cc prelu_op.cc
DEPS tensorrt_engine tensorrt_plugin operator scope framework_proto op_registry) DEPS tensorrt_engine tensorrt_plugin operator scope framework_proto op_registry)
nv_test(test_op_converter SRCS test_op_converter.cc DEPS nv_test(test_op_converter SRCS test_op_converter.cc DEPS
...@@ -16,7 +16,7 @@ nv_test(test_trt_fc_op SRCS test_fc_op.cc fc_op.cc ...@@ -16,7 +16,7 @@ nv_test(test_trt_fc_op SRCS test_fc_op.cc fc_op.cc
nv_test(test_trt_activation_op SRCS test_activation_op.cc activation_op.cc nv_test(test_trt_activation_op SRCS test_activation_op.cc activation_op.cc
DEPS ${FLUID_CORE_MODULES} tensorrt_engine activation_op SERIAL) DEPS ${FLUID_CORE_MODULES} tensorrt_engine activation_op SERIAL)
nv_test(test_trt_conv_op SRCS test_conv2d_op.cc conv2d_op.cc nv_test(test_trt_conv_op SRCS test_conv2d_op.cc conv2d_op.cc
DEPS ${FLUID_CORE_MODULES} tensorrt_engine conv_op SERIAL) DEPS ${FLUID_CORE_MODULES} tensorrt_engine conv_op conv_transpose_op SERIAL)
nv_test(test_trt_pool2d_op SRCS test_pool2d_op.cc pool2d_op.cc nv_test(test_trt_pool2d_op SRCS test_pool2d_op.cc pool2d_op.cc
DEPS ${FLUID_CORE_MODULES} tensorrt_engine pool_op SERIAL) DEPS ${FLUID_CORE_MODULES} tensorrt_engine pool_op SERIAL)
nv_test(test_trt_elementwise_op SRCS test_elementwise_op.cc elementwise_op.cc nv_test(test_trt_elementwise_op SRCS test_elementwise_op.cc elementwise_op.cc
...@@ -33,4 +33,7 @@ nv_test(test_trt_pad_op SRCS test_pad_op.cc pad_op.cc ...@@ -33,4 +33,7 @@ nv_test(test_trt_pad_op SRCS test_pad_op.cc pad_op.cc
DEPS ${FLUID_CORE_MODULES} tensorrt_engine pad_op SERIAL) DEPS ${FLUID_CORE_MODULES} tensorrt_engine pad_op SERIAL)
nv_test(test_trt_split_op SRCS test_split_op.cc split_op.cc nv_test(test_trt_split_op SRCS test_split_op.cc split_op.cc
DEPS ${FLUID_CORE_MODULES} tensorrt_engine tensorrt_plugin DEPS ${FLUID_CORE_MODULES} tensorrt_engine tensorrt_plugin
split_op concat_op SERIAL) split_op concat_op SERIAL)
nv_test(test_trt_prelu_op SRCS test_prelu_op.cc prelu_op.cc
DEPS ${FLUID_CORE_MODULES} tensorrt_engine tensorrt_plugin
prelu_op SERIAL)
...@@ -18,92 +18,139 @@ namespace paddle { ...@@ -18,92 +18,139 @@ namespace paddle {
namespace inference { namespace inference {
namespace tensorrt { namespace tensorrt {
bool to_skip_merging_optimize(TensorRTEngine* engine_, bool to_skip_merging_optimize(TensorRTEngine* engine,
const std::vector<int>& filters, const std::vector<int>& filters,
const std::vector<int>& strides, const std::vector<int>& strides,
const std::vector<int>& paddings, const std::vector<int>& paddings,
std::string input_name) { std::string input_name) {
if (engine_->itensor_quote_num[input_name] > 0) { if (engine->itensor_quote_num[input_name] > 0) {
return true; return true;
} }
if (filters[0] == 1 && filters[1] == 1 && strides[0] == 1 && if (filters[0] == 1 && filters[1] == 1 && strides[0] == 1 &&
strides[1] == 1 && paddings[0] == 0 && paddings[1] == 0) strides[1] == 1 && paddings[0] == 0 && paddings[1] == 0)
engine_->itensor_quote_num[input_name] += 1; engine->itensor_quote_num[input_name] += 1;
return false; return false;
} }
template <typename RegistFunc, typename SetDilationFunc>
void ConvertConv2d(TensorRTEngine* engine, const framework::proto::OpDesc& op,
const framework::Scope& scope, bool test_mode,
RegistFunc fadd_layer, SetDilationFunc fset_dilation,
const std::string& name) {
VLOG(3) << "convert a fluid " << name << " op to tensorrt layer without bias";
framework::OpDesc op_desc(op, nullptr);
PADDLE_ENFORCE_EQ(op_desc.Input("Input").size(), 1);
PADDLE_ENFORCE_EQ(op_desc.Input("Filter").size(), 1); // Y is a weight
PADDLE_ENFORCE_EQ(op_desc.Output("Output").size(), 1);
PADDLE_ENFORCE(engine != nullptr);
auto* X = engine->GetITensor(op_desc.Input("Input").front());
// Declare weights
auto* Y_v = scope.FindVar(op_desc.Input("Filter").front());
PADDLE_ENFORCE_NOT_NULL(Y_v);
auto* Y_t = Y_v->GetMutable<framework::LoDTensor>();
platform::CPUPlace cpu_place;
std::unique_ptr<framework::LoDTensor> weight_tensor(
new framework::LoDTensor());
weight_tensor->Resize(Y_t->dims());
TensorCopySync((*Y_t), cpu_place, weight_tensor.get());
auto* weight_data = weight_tensor->mutable_data<float>(platform::CPUPlace());
PADDLE_ENFORCE_EQ(weight_tensor->dims().size(), 4UL);
const int n_output = weight_tensor->dims()[0];
const int n_input = weight_tensor->dims()[1];
const int filter_h = weight_tensor->dims()[2];
const int filter_w = weight_tensor->dims()[3];
const int groups = boost::get<int>(op_desc.GetAttr("groups"));
const std::vector<int> dilations =
boost::get<std::vector<int>>(op_desc.GetAttr("dilations"));
const std::vector<int> strides =
boost::get<std::vector<int>>(op_desc.GetAttr("strides"));
const std::vector<int> paddings =
boost::get<std::vector<int>>(op_desc.GetAttr("paddings"));
nvinfer1::DimsHW nv_ksize(filter_h, filter_w);
nvinfer1::DimsHW nv_dilations(dilations[0], dilations[1]);
nvinfer1::DimsHW nv_strides(strides[0], strides[1]);
nvinfer1::DimsHW nv_paddings(paddings[0], paddings[1]);
TensorRTEngine::Weight weight{nvinfer1::DataType::kFLOAT,
static_cast<void*>(weight_data),
static_cast<size_t>(weight_tensor->numel())};
TensorRTEngine::Weight bias{nvinfer1::DataType::kFLOAT, nullptr, 0};
auto* layer = fadd_layer(const_cast<nvinfer1::ITensor*>(X), n_output, n_input,
nv_ksize, weight, bias);
PADDLE_ENFORCE(layer != nullptr);
layer->setStride(nv_strides);
layer->setPadding(nv_paddings);
layer->setNbGroups(groups);
// set dilations
fset_dilation(layer, nv_dilations);
auto output_name = op_desc.Output("Output").front();
layer->setName((name + " (Output: " + output_name + ")").c_str());
engine->weight_map[op_desc.Input("Filter").front()] =
std::move(weight_tensor);
layer->getOutput(0)->setName(output_name.c_str());
engine->SetITensor(output_name, layer->getOutput(0));
if (test_mode ||
to_skip_merging_optimize(engine, {filter_h, filter_w}, strides, paddings,
op_desc.Input("Input").front())) {
engine->DeclareOutput(output_name);
}
}
class Conv2dOpConverter : public OpConverter { class Conv2dOpConverter : public OpConverter {
public: public:
void operator()(const framework::proto::OpDesc& op, void operator()(const framework::proto::OpDesc& op,
const framework::Scope& scope, bool test_mode) override { const framework::Scope& scope, bool test_mode) override {
VLOG(3) << "convert a fluid conv2d op to tensorrt conv layer without bias"; ConvertConv2d(
engine_, op, scope, test_mode,
framework::OpDesc op_desc(op, nullptr); [&](nvinfer1::ITensor* inputs, int n_output, /* Conv output maps */
PADDLE_ENFORCE_EQ(op_desc.Input("Input").size(), 1); int n_input, /* Conv input maps */
PADDLE_ENFORCE_EQ(op_desc.Input("Filter").size(), 1); // Y is a weight nvinfer1::DimsHW& ksize, TensorRTEngine::Weight& weight,
PADDLE_ENFORCE_EQ(op_desc.Output("Output").size(), 1); TensorRTEngine::Weight& bias) -> nvinfer1::IConvolutionLayer* {
auto* layer =
auto* X = engine_->GetITensor(op_desc.Input("Input").front()); TRT_ENGINE_ADD_LAYER(engine_, Convolution, *inputs, n_output,
ksize, weight.get(), bias.get());
// Declare weights return layer;
auto* Y_v = scope.FindVar(op_desc.Input("Filter").front()); },
PADDLE_ENFORCE_NOT_NULL(Y_v); [](nvinfer1::IConvolutionLayer* layer, nvinfer1::DimsHW& dilations) {
auto* Y_t = Y_v->GetMutable<framework::LoDTensor>(); layer->setDilation(dilations);
},
platform::CPUPlace cpu_place; "conv2d");
std::unique_ptr<framework::LoDTensor> weight_tensor( }
new framework::LoDTensor()); };
weight_tensor->Resize(Y_t->dims());
TensorCopySync((*Y_t), cpu_place, weight_tensor.get()); class Deconv2dOpConverter : public OpConverter {
public:
auto* weight_data = void operator()(const framework::proto::OpDesc& op,
weight_tensor->mutable_data<float>(platform::CPUPlace()); const framework::Scope& scope, bool test_mode) override {
ConvertConv2d(
PADDLE_ENFORCE_EQ(weight_tensor->dims().size(), 4UL); engine_, op, scope, test_mode,
const int n_output = weight_tensor->dims()[0]; [&](nvinfer1::ITensor* inputs, int n_output, /* Deconv input maps */
const int filter_h = weight_tensor->dims()[2]; int n_input, /* Deconv output maps */
const int filter_w = weight_tensor->dims()[3]; nvinfer1::DimsHW& ksize, TensorRTEngine::Weight& weight,
TensorRTEngine::Weight& bias) -> nvinfer1::IDeconvolutionLayer* {
const int groups = boost::get<int>(op_desc.GetAttr("groups")); auto* layer =
const std::vector<int> dilations = TRT_ENGINE_ADD_LAYER(engine_, Deconvolution, *inputs, n_input,
boost::get<std::vector<int>>(op_desc.GetAttr("dilations")); ksize, weight.get(), bias.get());
const std::vector<int> strides = return layer;
boost::get<std::vector<int>>(op_desc.GetAttr("strides")); },
const std::vector<int> paddings = [](nvinfer1::IDeconvolutionLayer* layer, nvinfer1::DimsHW& dilations) {
boost::get<std::vector<int>>(op_desc.GetAttr("paddings")); PADDLE_ENFORCE(
dilations.d[0] == 1 && dilations.d[1] == 1,
nvinfer1::DimsHW nv_ksize(filter_h, filter_w); "Dilations must be (1, 1) for tensorRT, but given (%d, %d)",
nvinfer1::DimsHW nv_dilations(dilations[0], dilations[1]); dilations.d[0], dilations.d[1]);
nvinfer1::DimsHW nv_strides(strides[0], strides[1]); },
nvinfer1::DimsHW nv_paddings(paddings[0], paddings[1]); "conv2d_transpose");
TensorRTEngine::Weight weight{nvinfer1::DataType::kFLOAT,
static_cast<void*>(weight_data),
weight_tensor->memory_size() / sizeof(float)};
TensorRTEngine::Weight bias{nvinfer1::DataType::kFLOAT, nullptr, 0};
auto* layer = TRT_ENGINE_ADD_LAYER(
engine_, Convolution, *const_cast<nvinfer1::ITensor*>(X), n_output,
nv_ksize, weight.get(), bias.get());
PADDLE_ENFORCE(layer != nullptr);
layer->setStride(nv_strides);
layer->setPadding(nv_paddings);
layer->setDilation(nv_dilations);
layer->setNbGroups(groups);
auto output_name = op_desc.Output("Output").front();
layer->setName(("conv2d (Output: " + output_name + ")").c_str());
engine_->weight_map[op_desc.Input("Filter").front()] =
std::move(weight_tensor);
layer->getOutput(0)->setName(output_name.c_str());
engine_->SetITensor(output_name, layer->getOutput(0));
if (test_mode ||
to_skip_merging_optimize(engine_, {filter_h, filter_w}, strides,
paddings, op_desc.Input("Input").front())) {
engine_->DeclareOutput(output_name);
}
} }
}; };
...@@ -112,3 +159,4 @@ class Conv2dOpConverter : public OpConverter { ...@@ -112,3 +159,4 @@ class Conv2dOpConverter : public OpConverter {
} // namespace paddle } // namespace paddle
REGISTER_TRT_OP_CONVERTER(conv2d, Conv2dOpConverter); REGISTER_TRT_OP_CONVERTER(conv2d, Conv2dOpConverter);
REGISTER_TRT_OP_CONVERTER(conv2d_transpose, Deconv2dOpConverter);
...@@ -34,7 +34,8 @@ class ElementwiseWeightOpConverter : public OpConverter { ...@@ -34,7 +34,8 @@ class ElementwiseWeightOpConverter : public OpConverter {
auto* X = engine_->GetITensor(op_desc.Input("X").front()); auto* X = engine_->GetITensor(op_desc.Input("X").front());
nvinfer1::Dims dims_x = X->getDimensions(); nvinfer1::Dims dims_x = X->getDimensions();
PADDLE_ENFORCE(dims_x.nbDims >= 3); PADDLE_ENFORCE(dims_x.nbDims >= 3, "x dims experts 3, but %d is given.",
dims_x.nbDims);
auto* Y_v = scope.FindVar(op_desc.Input("Y").front()); auto* Y_v = scope.FindVar(op_desc.Input("Y").front());
PADDLE_ENFORCE_NOT_NULL(Y_v); PADDLE_ENFORCE_NOT_NULL(Y_v);
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/plugin/prelu_op_plugin.h"
namespace paddle {
namespace inference {
namespace tensorrt {
/*
* PRelu converter from fluid to tensorRT.
*/
class PReluOpConverter : public OpConverter {
public:
void operator()(const framework::proto::OpDesc& op,
const framework::Scope& scope, bool test_mode) override {
VLOG(4) << "convert fluid prelu op to tensorrt prelu layer";
framework::OpDesc op_desc(op, nullptr);
// Declare inputs
int input_num = op_desc.Input("X").size();
PADDLE_ENFORCE(input_num == 1);
auto* input = engine_->GetITensor(op_desc.Input("X")[0]);
// Get output
size_t output_num = op_desc.Output("Out").size();
PADDLE_ENFORCE(output_num == 1);
// Get attrs
std::string mode = boost::get<std::string>(op_desc.GetAttr("mode"));
//
auto* alpha_var = scope.FindVar(op_desc.Input("Alpha")[0]);
PADDLE_ENFORCE_NOT_NULL(alpha_var);
auto* alpha_tensor = alpha_var->GetMutable<framework::LoDTensor>();
platform::CUDAPlace place;
std::unique_ptr<framework::LoDTensor> alpha_tensor_device(
new framework::LoDTensor());
alpha_tensor_device->Resize(alpha_tensor->dims());
TensorCopySync(*alpha_tensor, place, alpha_tensor_device.get());
float* alpha_data = alpha_tensor_device->mutable_data<float>(place);
// Transform alpha to TensorRTEngine::Weight
TensorRTEngine::Weight alpha_rt(nvinfer1::DataType::kFLOAT,
static_cast<void*>(alpha_data),
alpha_tensor_device->numel());
PReluPlugin* plugin = new PReluPlugin(alpha_rt, mode);
nvinfer1::IPluginLayer* layer =
engine_->AddPlugin(&input, input_num, plugin);
// keep alpha tensor to avoid release it's memory
engine_->weight_map[op_desc.Input("Alpha")[0]] =
std::move(alpha_tensor_device);
std::string layer_name = "prelu (Output: ";
auto output_name = op_desc.Output("Out")[0];
layer->getOutput(0)->setName(output_name.c_str());
engine_->SetITensor(output_name, layer->getOutput(0));
layer_name += output_name;
if (test_mode) {
engine_->DeclareOutput(output_name);
}
layer->setName((layer_name + ")").c_str());
}
};
} // namespace tensorrt
} // namespace inference
} // namespace paddle
REGISTER_TRT_OP_CONVERTER(prelu, PReluOpConverter);
...@@ -26,7 +26,7 @@ class SplitOpConverter : public OpConverter { ...@@ -26,7 +26,7 @@ class SplitOpConverter : public OpConverter {
public: public:
void operator()(const framework::proto::OpDesc& op, void operator()(const framework::proto::OpDesc& op,
const framework::Scope& scope, bool test_mode) override { const framework::Scope& scope, bool test_mode) override {
VLOG(40) << "convert a fluid split op to tensorrt split layer"; VLOG(4) << "convert a fluid split op to tensorrt split layer";
framework::OpDesc op_desc(op, nullptr); framework::OpDesc op_desc(op, nullptr);
// Declare inputs // Declare inputs
......
...@@ -16,6 +16,9 @@ limitations under the License. */ ...@@ -16,6 +16,9 @@ limitations under the License. */
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h" #include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/convert/ut_helper.h" #include "paddle/fluid/inference/tensorrt/convert/ut_helper.h"
USE_OP(conv2d);
USE_OP(conv2d_transpose);
namespace paddle { namespace paddle {
namespace inference { namespace inference {
namespace tensorrt { namespace tensorrt {
...@@ -51,7 +54,37 @@ TEST(conv2d_op, test) { ...@@ -51,7 +54,37 @@ TEST(conv2d_op, test) {
validator.Execute(3); validator.Execute(3);
} }
TEST(conv2d_transpose_op, test) {
std::unordered_set<std::string> parameters({"deconv2d-Y"});
framework::Scope scope;
TRTConvertValidation validator(5, parameters, scope, 1 << 15);
validator.DeclInputVar("deconv2d-X", nvinfer1::Dims3(3, 5, 5));
validator.DeclParamVar("deconv2d-Y", nvinfer1::Dims4(3, 2, 3, 3));
validator.DeclOutputVar("deconv2d-Out", nvinfer1::Dims3(2, 5, 5));
// Prepare Op description
framework::OpDesc desc;
desc.SetType("conv2d_transpose");
desc.SetInput("Input", {"deconv2d-X"});
desc.SetInput("Filter", {"deconv2d-Y"});
desc.SetOutput("Output", {"deconv2d-Out"});
const std::vector<int> strides({1, 1});
const std::vector<int> paddings({1, 1});
const std::vector<int> dilations({1, 1});
const int groups = 1;
desc.SetAttr("strides", strides);
desc.SetAttr("paddings", paddings);
desc.SetAttr("dilations", dilations);
desc.SetAttr("groups", groups);
validator.SetOp(*desc.Proto());
validator.Execute(3);
}
} // namespace tensorrt } // namespace tensorrt
} // namespace inference } // namespace inference
} // namespace paddle } // namespace paddle
USE_OP(conv2d);
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <gtest/gtest.h>
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/convert/ut_helper.h"
namespace paddle {
namespace inference {
namespace tensorrt {
TEST(prelu_op, test_channel_wise) {
std::unordered_set<std::string> parameters({"prelu_alpha"});
framework::Scope scope;
TRTConvertValidation validator(10, parameters, scope, 1000);
validator.DeclInputVar("prelu_input", nvinfer1::DimsCHW(3, 2, 2));
validator.DeclParamVar("prelu_alpha", nvinfer1::Dims3(3, 1, 1));
validator.DeclOutputVar("prelu_out", nvinfer1::DimsCHW(3, 2, 2));
// Prepare Op description
framework::OpDesc desc;
desc.SetType("prelu");
desc.SetInput("X", {"prelu_input"});
desc.SetInput("Alpha", {"prelu_alpha"});
desc.SetOutput("Out", {"prelu_out"});
desc.SetAttr("mode", std::string("channel"));
validator.SetOp(*desc.Proto());
validator.Execute(1);
}
TEST(prelu_op, test_element_wise) {
std::unordered_set<std::string> parameters({"prelu_alpha"});
framework::Scope scope;
TRTConvertValidation validator(10, parameters, scope, 1000);
validator.DeclInputVar("prelu_input", nvinfer1::DimsCHW(3, 2, 2));
validator.DeclParamVar("prelu_alpha", nvinfer1::Dims4(10, 3, 2, 2));
validator.DeclOutputVar("prelu_out", nvinfer1::DimsCHW(3, 2, 2));
// Prepare Op description
framework::OpDesc desc;
desc.SetType("prelu");
desc.SetInput("X", {"prelu_input"});
desc.SetInput("Alpha", {"prelu_alpha"});
desc.SetOutput("Out", {"prelu_out"});
desc.SetAttr("mode", std::string("element"));
validator.SetOp(*desc.Proto());
validator.Execute(1);
}
TEST(prelu_op, test_scalar) {
std::unordered_set<std::string> parameters({"prelu_alpha"});
framework::Scope scope;
TRTConvertValidation validator(10, parameters, scope, 1000);
validator.DeclInputVar("prelu_input", nvinfer1::DimsCHW(3, 2, 2));
validator.DeclParamVar("prelu_alpha", nvinfer1::Dims3(1, 1, 1));
validator.DeclOutputVar("prelu_out", nvinfer1::DimsCHW(3, 2, 2));
// Prepare Op description
framework::OpDesc desc;
desc.SetType("prelu");
desc.SetInput("X", {"prelu_input"});
desc.SetInput("Alpha", {"prelu_alpha"});
desc.SetOutput("Out", {"prelu_out"});
desc.SetAttr("mode", std::string("all"));
validator.SetOp(*desc.Proto());
validator.Execute(1);
}
} // namespace tensorrt
} // namespace inference
} // namespace paddle
// USE_OP(prelu);
USE_CPU_ONLY_OP(prelu);
...@@ -200,7 +200,8 @@ void TensorRTEngine::GetOutputInCPU(const std::string &name, void *dst, ...@@ -200,7 +200,8 @@ void TensorRTEngine::GetOutputInCPU(const std::string &name, void *dst,
Buffer &TensorRTEngine::buffer(const std::string &name) { Buffer &TensorRTEngine::buffer(const std::string &name) {
PADDLE_ENFORCE(infer_engine_ != nullptr, "call FreezeNetwork first."); PADDLE_ENFORCE(infer_engine_ != nullptr, "call FreezeNetwork first.");
auto it = buffer_sizes_.find(name); auto it = buffer_sizes_.find(name);
PADDLE_ENFORCE(it != buffer_sizes_.end()); PADDLE_ENFORCE(it != buffer_sizes_.end(), "tried to access buffer named %s",
name);
auto slot_offset = infer_engine_->getBindingIndex(name.c_str()); auto slot_offset = infer_engine_->getBindingIndex(name.c_str());
return buffers_[slot_offset]; return buffers_[slot_offset];
} }
......
...@@ -40,6 +40,7 @@ class TensorRTEngine : public EngineBase { ...@@ -40,6 +40,7 @@ class TensorRTEngine : public EngineBase {
// Weight is model parameter. // Weight is model parameter.
class Weight { class Weight {
public: public:
Weight() = default;
Weight(nvinfer1::DataType dtype, void* value, size_t num_elem) { Weight(nvinfer1::DataType dtype, void* value, size_t num_elem) {
w_.type = dtype; w_.type = dtype;
w_.values = value; w_.values = value;
......
nv_library(tensorrt_plugin SRCS trt_plugin.cc split_op_plugin.cu DEPS enforce) nv_library(tensorrt_plugin SRCS trt_plugin.cc split_op_plugin.cu prelu_op_plugin.cu DEPS enforce)
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <stdio.h>
#include <cassert>
#include "glog/logging.h"
#include "paddle/fluid/inference/tensorrt/plugin/prelu_op_plugin.h"
namespace paddle {
namespace inference {
namespace tensorrt {
static const int CUDA_NUM_THREADS = 1024;
static const int CUDA_MAX_NUM_BLOCKS = 65535;
inline static int GET_NUM_BLOCKS(const int N) {
return (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS;
}
__global__ void PReluChannelWiseKernel(const float *input, const float *alpha,
float *output, int channel,
size_t spatial_size) {
size_t offset = blockIdx.x * spatial_size;
const float *in = input + offset;
float *out = output + offset;
float scale = alpha[blockIdx.x % channel];
for (size_t i = threadIdx.x; i < spatial_size; i += blockDim.x) {
float x = in[i];
out[i] = (x > 0) ? x : scale * x;
}
}
__global__ void PReluElementWiseKernel(const float *input, const float *alpha,
float *output, size_t spatial_size) {
size_t offset = blockIdx.x * spatial_size;
const float *in = input + offset;
const float *scale = alpha + offset;
float *out = output + offset;
for (size_t i = threadIdx.x; i < spatial_size; i += blockDim.x) {
float x = in[i];
out[i] = (x > 0) ? x : scale[i] * x;
}
}
__global__ void PReluScalarKernel(const float *input, const float *alpha,
float *output, size_t spatial_size) {
size_t offset = blockIdx.x * spatial_size;
const float *in = input + offset;
float scale = *alpha;
float *out = output + offset;
for (size_t i = threadIdx.x; i < spatial_size; i += blockDim.x) {
float x = in[i];
out[i] = (x > 0) ? x : scale * x;
}
}
static inline void PReluChannelWise(cudaStream_t stream, const float *input,
const float *alpha, float *output,
int batch_size,
const nvinfer1::Dims &dims) {
size_t unroll = batch_size * dims.d[0];
size_t spatial_size = dims.d[1] * dims.d[2];
CHECK_LT(unroll, CUDA_MAX_NUM_BLOCKS);
PReluChannelWiseKernel<<<unroll, CUDA_NUM_THREADS, 0, stream>>>(
input, alpha, output, dims.d[0], spatial_size);
}
static inline void PReluElementWise(cudaStream_t stream, const float *input,
const float *alpha, float *output,
int batch_size,
const nvinfer1::Dims &dims) {
size_t unroll = batch_size * dims.d[0];
size_t spatial_size = dims.d[1] * dims.d[2];
CHECK_LT(unroll, CUDA_MAX_NUM_BLOCKS);
PReluElementWiseKernel<<<unroll, CUDA_NUM_THREADS, 0, stream>>>(
input, alpha, output, spatial_size);
}
static inline void PReluScalar(cudaStream_t stream, const float *input,
const float *alpha, float *output,
int batch_size, const nvinfer1::Dims &dims) {
size_t unroll = batch_size * dims.d[0];
size_t spatial_size = dims.d[1] * dims.d[2];
CHECK_LT(unroll, CUDA_MAX_NUM_BLOCKS);
PReluScalarKernel<<<unroll, CUDA_NUM_THREADS, 0, stream>>>(
input, alpha, output, spatial_size);
}
nvinfer1::Dims PReluPlugin::getOutputDimensions(int index,
const nvinfer1::Dims *inputDims,
int nbInputs) {
assert(nbInputs == 1);
assert(index < this->getNbOutputs());
nvinfer1::Dims const &input_dims = inputDims[0];
nvinfer1::Dims output_dims = input_dims;
return output_dims;
}
int PReluPlugin::enqueue(int batchSize, const void *const *inputs,
void **outputs, void *workspace, cudaStream_t stream) {
// input dims is CHW.
const auto &input_dims = this->getInputDims(0);
const float *input = reinterpret_cast<const float *>(inputs[0]);
const float *alpha = reinterpret_cast<const float *>(alpha_.get().values);
float *output = reinterpret_cast<float **>(outputs)[0];
if (mode_ == "channel") {
PReluChannelWise(stream, input, alpha, output, batchSize, input_dims);
} else if (mode_ == "element") {
PReluElementWise(stream, input, alpha, output, batchSize, input_dims);
} else {
PReluScalar(stream, input, alpha, output, batchSize, input_dims);
}
return cudaGetLastError() != cudaSuccess;
}
} // namespace tensorrt
} // namespace inference
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "paddle/fluid/inference/tensorrt/engine.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
namespace paddle {
namespace inference {
namespace tensorrt {
class PReluPlugin : public PluginTensorRT {
TensorRTEngine::Weight alpha_;
std::string mode_;
protected:
size_t getSerializationSize() override {
// return getBaseSerializationSize(alpha_) + SerializedSize(mode_);
return 0;
}
// TRT will call this func when we need to serialize the configuration of
// tensorrt.
// It should not be called by users.
void serialize(void *buffer) override {
// serializeBase(buffer);
// SerializeValue(&buffer, alpha_);
// SerializeValue(&buffer, mode_);
}
public:
PReluPlugin(TensorRTEngine::Weight const &alpha, std::string const &mode)
: alpha_(alpha), mode_(mode) {}
// It was used for tensorrt deserialization.
// It should not be called by users.
PReluPlugin(void const *serialData, size_t serialLength) {
// deserializeBase(serialData, serialLength);
// DeserializeValue(&serialData, &serialLength, &alpha_);
// DeserializeValue(&serialData, &serialLength, &mode_);
}
PReluPlugin *clone() const override { return new PReluPlugin(alpha_, mode_); }
const char *getPluginType() const override { return "prelu"; }
int getNbOutputs() const override { return 1; }
nvinfer1::Dims getOutputDimensions(int index, const nvinfer1::Dims *inputs,
int nbInputDims) override;
int enqueue(int batchSize, const void *const *inputs, void **outputs,
void *workspace, cudaStream_t stream) override;
};
} // namespace tensorrt
} // namespace inference
} // namespace paddle
...@@ -300,7 +300,6 @@ if (NOT WIN32) ...@@ -300,7 +300,6 @@ if (NOT WIN32)
op_library(gru_op DEPS sequence2batch gru_compute) op_library(gru_op DEPS sequence2batch gru_compute)
endif(NOT WIN32) endif(NOT WIN32)
op_library(recurrent_op DEPS executor) op_library(recurrent_op DEPS executor)
op_library(warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale)
op_library(cos_sim_op DEPS cos_sim_functor) op_library(cos_sim_op DEPS cos_sim_functor)
op_library(parallel_do_op DEPS executor) op_library(parallel_do_op DEPS executor)
op_library(unsqueeze_op DEPS reshape_op) op_library(unsqueeze_op DEPS reshape_op)
...@@ -309,6 +308,7 @@ op_library(flatten_op DEPS reshape_op) ...@@ -309,6 +308,7 @@ op_library(flatten_op DEPS reshape_op)
op_library(sequence_pad_op DEPS sequence_padding) op_library(sequence_pad_op DEPS sequence_padding)
op_library(unstack_op DEPS stack_op) op_library(unstack_op DEPS stack_op)
op_library(fake_quantize_op DEPS memory) op_library(fake_quantize_op DEPS memory)
op_library(nce_op DEPS sampler)
if (NOT WIN32) if (NOT WIN32)
op_library(crf_decoding_op DEPS jit_kernel) op_library(crf_decoding_op DEPS jit_kernel)
op_library(fusion_lstm_op DEPS jit_kernel) op_library(fusion_lstm_op DEPS jit_kernel)
...@@ -331,6 +331,14 @@ op_library(load_combine_op DEPS lod_tensor) ...@@ -331,6 +331,14 @@ op_library(load_combine_op DEPS lod_tensor)
op_library(concat_op DEPS concat_and_split) op_library(concat_op DEPS concat_and_split)
op_library(tensor_array_to_tensor_op DEPS concat_op) op_library(tensor_array_to_tensor_op DEPS concat_op)
set(DEPS_OPS ${DEPS_OPS} warpctc_op)
if (WITH_GPU)
if (${CUDNN_MAJOR_VERSION} VERSION_LESS 7)
op_library(warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale SRCS warpctc_op.cc warpctc_op.cu.cc)
endif()
endif()
op_library(warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale)
list(REMOVE_ITEM GENERAL_OPS ${DEPS_OPS}) list(REMOVE_ITEM GENERAL_OPS ${DEPS_OPS})
foreach(src ${GENERAL_OPS}) foreach(src ${GENERAL_OPS})
......
...@@ -47,6 +47,11 @@ class ExpandOp : public framework::OperatorWithKernel { ...@@ -47,6 +47,11 @@ class ExpandOp : public framework::OperatorWithKernel {
out_shape[i] = x_dims[i] * expand_times[i]; out_shape[i] = x_dims[i] * expand_times[i];
} }
// set the first dim to -1 in compile time
if (!ctx->IsRuntime()) {
out_shape[0] = x_dims[0];
}
ctx->SetOutputDim("Out", framework::make_ddim(out_shape)); ctx->SetOutputDim("Out", framework::make_ddim(out_shape));
if (out_shape[0] == x_dims[0]) { if (out_shape[0] == x_dims[0]) {
ctx->ShareLoD("X", "Out"); ctx->ShareLoD("X", "Out");
...@@ -109,7 +114,16 @@ class ExpandGradOp : public framework::OperatorWithKernel { ...@@ -109,7 +114,16 @@ class ExpandGradOp : public framework::OperatorWithKernel {
ctx->Attrs().Get<std::vector<int>>("expand_times"); ctx->Attrs().Get<std::vector<int>>("expand_times");
auto out_dims = ctx->GetInputDim(framework::GradVarName("Out")); auto out_dims = ctx->GetInputDim(framework::GradVarName("Out"));
for (size_t i = 0; i < expand_times.size(); ++i) { size_t start_pos = 0u;
if (!ctx->IsRuntime()) {
PADDLE_ENFORCE_EQ(
x_dims[0], out_dims[0],
"The first dimension size of Input(Out@GRAD) should be "
"equal to the crroresponding dimension size of Input(X)");
start_pos = 1u;
}
for (size_t i = start_pos; i < expand_times.size(); ++i) {
PADDLE_ENFORCE_EQ(x_dims[i] * expand_times[i], out_dims[i], PADDLE_ENFORCE_EQ(x_dims[i] * expand_times[i], out_dims[i],
"Each dimension size of Input(Out@GRAD) should be " "Each dimension size of Input(Out@GRAD) should be "
"equal to multiplication of crroresponding dimension " "equal to multiplication of crroresponding dimension "
......
...@@ -41,6 +41,7 @@ math_library(cross_entropy) ...@@ -41,6 +41,7 @@ math_library(cross_entropy)
math_library(cos_sim_functor) math_library(cos_sim_functor)
math_library(depthwise_conv) math_library(depthwise_conv)
math_library(im2col) math_library(im2col)
math_library(sampler)
if (NOT WIN32) # windows do not support avx functions yet. if (NOT WIN32) # windows do not support avx functions yet.
math_library(gru_compute DEPS activation_functions math_function) math_library(gru_compute DEPS activation_functions math_function)
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. /* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
...@@ -13,52 +13,46 @@ See the License for the specific language governing permissions and ...@@ -13,52 +13,46 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/math/sampler.h" #include "paddle/fluid/operators/math/sampler.h"
#include <iostream>
#include <queue>
#include <utility>
#include <vector>
namespace paddle { namespace paddle {
namespace random { namespace operators {
namespace math {
Sampler::~Sampler() {} Sampler::~Sampler() {}
UniformSampler::UniformSampler(int64 range) UniformSampler::UniformSampler(int64_t range, unsigned int seed)
: Sampler(range), inv_range_(1.0 / range) { : Sampler(range, seed), inv_range_(1.0 / (range + 1)) {
random_engine_ = std::make_shared<std::mt19937>(seed_); random_engine_ = std::make_shared<std::mt19937_64>(seed_);
dist_ = std::make_shared<std::uniform_int_distribution<>>(0, range); dist_ = std::make_shared<std::uniform_int_distribution<>>(0, range);
} }
UniformSampler::UniformSampler(int64 range, unsigned int seed) int64_t UniformSampler::Sample() const { return (*dist_)(*random_engine_); }
: Sampler(range, seed), inv_range_(1.0 / range) {
random_engine_ = std::make_shared<std::mt19937>(seed_);
dist_ = std::make_shared<std::uniform_int_distribution<>>(0, range);
}
int64 UniformSampler::Sample() const { return (*dist_)(*random_engine_); }
float UniformSampler::Probability(int64 value) const { return inv_range_; } float UniformSampler::Probability(int64_t value) const { return inv_range_; }
LogUniformSampler::LogUniformSampler(int64 range) LogUniformSampler::LogUniformSampler(int64_t range, unsigned int seed)
: Sampler(range), log_range_(log(range + 1)) {
random_engine_ = std::make_shared<std::mt19937>(seed_);
dist_ = std::make_shared<std::uniform_real_distribution<>>(0, 1);
}
LogUniformSampler::LogUniformSampler(int64 range, unsigned int seed)
: Sampler(range, seed), log_range_(log(range + 1)) { : Sampler(range, seed), log_range_(log(range + 1)) {
random_engine_ = std::make_shared<std::mt19937>(seed_); random_engine_ = std::make_shared<std::mt19937_64>(seed_);
dist_ = std::make_shared<std::uniform_real_distribution<>>(0, 1); dist_ = std::make_shared<std::uniform_real_distribution<>>(0, 1);
} }
int64 LogUniformSampler::Sample() const {
int64_t LogUniformSampler::Sample() const {
// Got Log Uniform distribution from uniform distribution by // Got Log Uniform distribution from uniform distribution by
// inverse_transform_sampling method // inverse_transform_sampling method
// More details: // More details:
// https://wanghaoshuang.github.io/2017/11/Log-uniform-distribution-sampler/ // https://wanghaoshuang.github.io/2017/11/Log-uniform-distribution-sampler/
const int64 value = const int64_t value =
static_cast<int64>(exp((*dist_)(*random_engine_) * log_range_)) - 1; static_cast<int64_t>(exp((*dist_)(*random_engine_) * log_range_)) - 1;
// Mathematically, value should be <= range_, but might not be due to some // Mathematically, value should be <= range_, but might not be due to some
// floating point roundoff, so we mod by range_. // floating point roundoff, so we mod by range_.
return value % range_; return value % range_;
} }
float LogUniformSampler::Probability(int64 value) const { float LogUniformSampler::Probability(int64_t value) const {
// Given f(x) = 1/[(x+1) * log_range_] // Given f(x) = 1/[(x+1) * log_range_]
// The value's probability is integral of f(x) from value to (value + 1) // The value's probability is integral of f(x) from value to (value + 1)
// More details: // More details:
...@@ -66,5 +60,76 @@ float LogUniformSampler::Probability(int64 value) const { ...@@ -66,5 +60,76 @@ float LogUniformSampler::Probability(int64 value) const {
return (log((value + 2.0) / (value + 1.0))) / log_range_; return (log((value + 2.0) / (value + 1.0))) / log_range_;
} }
} // namespace random CustomSampler::CustomSampler(int64_t range, const float* probabilities,
unsigned int seed)
: Sampler(range, seed) {
random_engine_ = std::make_shared<std::mt19937_64>(seed_);
real_dist_ = std::make_shared<std::uniform_real_distribution<>>(0, 1);
int_dist_ = std::make_shared<std::uniform_int_distribution<>>(0, range);
alias_probs_ = std::make_shared<std::vector<float>>(range + 1);
alias_ = std::make_shared<std::vector<int64_t>>(range + 1);
probs_ = std::make_shared<std::vector<float>>(range + 1);
std::queue<std::pair<int64_t, float>> bigs;
std::queue<std::pair<int64_t, float>> littles;
for (int64_t i = 0; i <= range; ++i) {
(*probs_)[i] = probabilities[i];
float normal_prob = probabilities[i] * (range + 1);
if (normal_prob - 1.0 > 1e-4) {
bigs.emplace(i, normal_prob);
} else if (1.0 - normal_prob > 1e-4) {
littles.emplace(i, normal_prob);
} else {
(*alias_probs_)[i] = normal_prob;
(*alias_)[i] = -1;
}
}
while ((!littles.empty()) && (!bigs.empty())) {
auto big = bigs.front();
auto little = littles.front();
bigs.pop();
littles.pop();
(*alias_probs_)[little.first] = little.second;
(*alias_)[little.first] = big.first;
auto big_left = big.second - (1 - little.second);
if (big_left - 1.0 > 1e-4) {
bigs.emplace(big.first, big_left);
} else if (1.0 - big_left > 1e-4) {
littles.emplace(big.first, big_left);
} else {
(*alias_probs_)[big.first] = big_left;
(*alias_)[big.first] = -1;
}
}
if (!littles.empty()) { // littles.second is close to 1.0
auto little = littles.front();
(*alias_probs_)[little.first] = 1.0;
(*alias_)[little.first] = -1;
}
if (!bigs.empty()) { // bigs.second is close to 1.0
auto big = bigs.front();
(*alias_probs_)[big.first] = 1.0;
(*alias_)[big.first] = -1;
}
}
int64_t CustomSampler::Sample() const {
auto index = (*int_dist_)(*random_engine_);
auto p = (*real_dist_)(*random_engine_);
if (p > (*alias_probs_)[index]) {
return (*alias_)[index];
} else {
return index;
}
}
float CustomSampler::Probability(int64_t value) const {
return (*probs_)[value];
}
} // namespace math
} // namespace operators
} // namespace paddle } // namespace paddle
...@@ -16,6 +16,8 @@ limitations under the License. */ ...@@ -16,6 +16,8 @@ limitations under the License. */
#include <cstdint> #include <cstdint>
#include <memory> #include <memory>
#include <random> #include <random>
#include <vector>
namespace paddle { namespace paddle {
namespace operators { namespace operators {
namespace math { namespace math {
...@@ -27,14 +29,14 @@ namespace math { ...@@ -27,14 +29,14 @@ namespace math {
*/ */
class Sampler { class Sampler {
public: public:
explicit Sampler(int64_t range) : range_(range) { explicit Sampler(int64_t range, unsigned int seed = 0UL) : range_(range) {
PADDLE_ENFORCE_GT(range, 0); // PADDLE_ENFORCE_GT(range, 0, "Range should be greater than 0.");
std::random_device r; if (seed == 0) {
seed_ = r(); std::random_device r;
} seed_ = r();
explicit Sampler(int64_t range, unsigned int seed) } else {
: range_(range), seed_(seed) { seed_ = seed;
PADDLE_ENFORCE_GT(range, 0); }
} }
virtual ~Sampler(); virtual ~Sampler();
// Sample a single value // Sample a single value
...@@ -42,7 +44,7 @@ class Sampler { ...@@ -42,7 +44,7 @@ class Sampler {
// The probability that a single call to Sample() returns the given value. // The probability that a single call to Sample() returns the given value.
virtual float Probability(int64_t value) const = 0; virtual float Probability(int64_t value) const = 0;
int64 range() { return range_; } int64_t range() { return range_; }
protected: protected:
const int64_t range_; const int64_t range_;
...@@ -56,13 +58,11 @@ class Sampler { ...@@ -56,13 +58,11 @@ class Sampler {
*/ */
class UniformSampler : public Sampler { class UniformSampler : public Sampler {
public: public:
explicit UniformSampler(int64_t range); explicit UniformSampler(int64_t range, unsigned int seed = 0UL);
explicit UniformSampler(int64_t range, unsigned int seed);
~UniformSampler() override {} ~UniformSampler() override {}
int64 Sample() const override; int64_t Sample() const override;
float Probability(int64_t value) const override; float Probability(int64_t value) const override;
...@@ -79,13 +79,11 @@ class UniformSampler : public Sampler { ...@@ -79,13 +79,11 @@ class UniformSampler : public Sampler {
*/ */
class LogUniformSampler : public Sampler { class LogUniformSampler : public Sampler {
public: public:
explicit LogUniformSampler(int64_t range); explicit LogUniformSampler(int64_t range, unsigned int seed = 0UL);
explicit LogUniformSampler(int64_t range, unsigned int seed);
~LogUniformSampler() override {} ~LogUniformSampler() override {}
int64 Sample() const override; int64_t Sample() const override;
float Probability(int64_t value) const override; float Probability(int64_t value) const override;
...@@ -95,6 +93,29 @@ class LogUniformSampler : public Sampler { ...@@ -95,6 +93,29 @@ class LogUniformSampler : public Sampler {
std::shared_ptr<std::uniform_real_distribution<>> dist_; std::shared_ptr<std::uniform_real_distribution<>> dist_;
}; };
/**
* Sample integers from [0, range) from custom distribution.
*/
class CustomSampler : public Sampler {
public:
explicit CustomSampler(int64_t range, const float* probabilities,
unsigned int seed = 0UL);
~CustomSampler() override {}
int64_t Sample() const override;
float Probability(int64_t value) const override;
private:
std::shared_ptr<std::vector<float>> alias_probs_;
std::shared_ptr<std::vector<int64_t>> alias_;
std::shared_ptr<std::vector<float>> probs_;
std::shared_ptr<std::mt19937_64> random_engine_;
std::shared_ptr<std::uniform_real_distribution<>> real_dist_;
std::shared_ptr<std::uniform_int_distribution<>> int_dist_;
};
} // namespace math } // namespace math
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -35,6 +35,7 @@ class NCEOp : public framework::OperatorWithKernel { ...@@ -35,6 +35,7 @@ class NCEOp : public framework::OperatorWithKernel {
auto x_dims = ctx->GetInputDim("Input"); auto x_dims = ctx->GetInputDim("Input");
auto label_dims = ctx->GetInputDim("Label"); auto label_dims = ctx->GetInputDim("Label");
auto w_dims = ctx->GetInputDim("Weight");
PADDLE_ENFORCE_EQ(x_dims[0], label_dims[0]); PADDLE_ENFORCE_EQ(x_dims[0], label_dims[0]);
int num_true_classes = label_dims.size() == 2 ? label_dims[1] : 1; int num_true_classes = label_dims.size() == 2 ? label_dims[1] : 1;
if (ctx->HasInput("Bias")) { if (ctx->HasInput("Bias")) {
...@@ -98,6 +99,13 @@ class NCEOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -98,6 +99,13 @@ class NCEOpMaker : public framework::OpProtoAndCheckerMaker {
"each sample. And it is a dispensable input. The default value of " "each sample. And it is a dispensable input. The default value of "
"sample is 1.") "sample is 1.")
.AsDispensable(); .AsDispensable();
AddInput(
"CustomDistribution",
"(Tensor) It is used in 'CostumDist' sampler. "
"It is a tensor with shape [num_total_classes]."
"The i-th element is the probsbility of the i-th class being sampled.")
.AsDispensable();
AddOutput("Cost", AddOutput("Cost",
"(Tensor) A tensor of shape [batch_size, 1]. Cost of samples."); "(Tensor) A tensor of shape [batch_size, 1]. Cost of samples.");
AddOutput("SampleLogits", AddOutput("SampleLogits",
...@@ -121,6 +129,17 @@ class NCEOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -121,6 +129,17 @@ class NCEOpMaker : public framework::OpProtoAndCheckerMaker {
AddAttr<int>("num_neg_samples", AddAttr<int>("num_neg_samples",
"The number of negative classes. The default value is 10.") "The number of negative classes. The default value is 10.")
.SetDefault(10); .SetDefault(10);
AddAttr<int>("sampler",
"(int) Which sampler to be used to sample negative class."
"0: Uniform; 1: LogUniform; 2: CostumDist.")
.SetDefault(0);
AddAttr<int>("seed",
"(int) The seed used in sampler. If it is 0, "
"the sampler will generate a seed randomly.")
.SetDefault(0);
AddAttr<std::vector<int>>("custom_neg_classes", AddAttr<std::vector<int>>("custom_neg_classes",
"This attribute only be used in unitest. Classes " "This attribute only be used in unitest. Classes "
"in this list wiil be used as negative classes " "in this list wiil be used as negative classes "
......
...@@ -19,29 +19,28 @@ limitations under the License. */ ...@@ -19,29 +19,28 @@ limitations under the License. */
#include <vector> #include <vector>
#include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/sampler.h"
#include "unsupported/Eigen/CXX11/Tensor" #include "unsupported/Eigen/CXX11/Tensor"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
using Tensor = framework::Tensor; using Tensor = framework::Tensor;
using Sampler = math::Sampler;
template <typename T, int MajorType = Eigen::RowMajor, template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex> typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>; using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
void PrepareSamples(const framework::ExecutionContext& context) { void PrepareSamples(const framework::ExecutionContext& context,
Sampler* sampler) {
auto label = context.Input<Tensor>("Label"); auto label = context.Input<Tensor>("Label");
const int64_t* label_data = label->data<int64_t>(); const int64_t* label_data = label->data<int64_t>();
auto label_dims = label->dims(); auto label_dims = label->dims();
int num_total_classes = context.Attr<int>("num_total_classes"); // int num_total_classes = context.Attr<int>("num_total_classes");
// for unitest // for unitest
std::vector<int> custom_neg_classes = std::vector<int> custom_neg_classes =
context.Attr<std::vector<int>>("custom_neg_classes"); context.Attr<std::vector<int>>("custom_neg_classes");
// random machine
std::random_device rd;
std::mt19937 rng(rd());
std::uniform_int_distribution<int> rand(0, num_total_classes - 1);
auto sample_labels = context.Output<Tensor>("SampleLabels"); auto sample_labels = context.Output<Tensor>("SampleLabels");
auto sample_labels_dims = sample_labels->dims(); auto sample_labels_dims = sample_labels->dims();
...@@ -62,7 +61,7 @@ void PrepareSamples(const framework::ExecutionContext& context) { ...@@ -62,7 +61,7 @@ void PrepareSamples(const framework::ExecutionContext& context) {
} else { } else {
for (; j < sample_labels_dims[1]; ++j) { for (; j < sample_labels_dims[1]; ++j) {
// TODO(wanghaoshuang): support more distribution sampling // TODO(wanghaoshuang): support more distribution sampling
sample_labels_data[index++] = rand(rng); sample_labels_data[index++] = sampler->Sample();
} }
} }
} }
...@@ -72,7 +71,33 @@ template <typename DeviceContext, typename T> ...@@ -72,7 +71,33 @@ template <typename DeviceContext, typename T>
class NCEKernel : public framework::OpKernel<T> { class NCEKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
PrepareSamples<DeviceContext, T>(context); int sampler_type = context.Attr<int>("sampler");
int seed = context.Attr<int>("seed");
int num_total_classes = context.Attr<int>("num_total_classes");
int num_neg_samples = context.Attr<int>("num_neg_samples");
Sampler* sampler;
switch (sampler_type) {
case 0: {
sampler = new math::UniformSampler(num_total_classes - 1, seed);
break;
}
case 1: {
sampler = new math::LogUniformSampler(num_total_classes - 1, seed);
break;
}
case 2: {
auto custom_dist = context.Input<Tensor>("CustomDistribution");
const float* custom_dist_data = custom_dist->data<float>();
PADDLE_ENFORCE_EQ(custom_dist->numel(), num_total_classes);
sampler = new math::CustomSampler(num_total_classes - 1,
custom_dist_data, seed);
break;
}
default: { PADDLE_THROW("Unsupported SamplerType."); }
}
PrepareSamples<DeviceContext, T>(context, sampler);
auto sample_labels = context.Output<Tensor>("SampleLabels"); auto sample_labels = context.Output<Tensor>("SampleLabels");
const int64_t* sample_labels_data = sample_labels->data<int64_t>(); const int64_t* sample_labels_data = sample_labels->data<int64_t>();
auto sample_out = context.Output<Tensor>("SampleLogits"); auto sample_out = context.Output<Tensor>("SampleLogits");
...@@ -85,13 +110,12 @@ class NCEKernel : public framework::OpKernel<T> { ...@@ -85,13 +110,12 @@ class NCEKernel : public framework::OpKernel<T> {
} }
auto out = context.Output<Tensor>("Cost"); auto out = context.Output<Tensor>("Cost");
T* out_data = out->mutable_data<T>(context.GetPlace()); T* out_data = out->mutable_data<T>(context.GetPlace());
int num_neg_samples = context.Attr<int>("num_neg_samples");
int num_total_classes = context.Attr<int>("num_total_classes");
int64_t num_true_class = 1; int64_t num_true_class = 1;
if (label != nullptr) { if (label != nullptr) {
num_true_class = label->dims()[1]; num_true_class = label->dims()[1];
} }
T b = 1. / num_total_classes * num_neg_samples; int64_t sampled_labels_num = sample_labels->dims()[1];
// T b = 1. / num_total_classes * num_neg_samples;
// forward bias // forward bias
auto bias = context.Input<Tensor>("Bias"); auto bias = context.Input<Tensor>("Bias");
if (bias != nullptr) { if (bias != nullptr) {
...@@ -117,22 +141,17 @@ class NCEKernel : public framework::OpKernel<T> { ...@@ -117,22 +141,17 @@ class NCEKernel : public framework::OpKernel<T> {
} }
// forward cost // forward cost
for (int64_t i = 0; i < sample_labels->dims()[0]; ++i) { for (int64_t i = 0; i < sample_labels->dims()[0]; ++i) {
int64_t j = 0;
out_data[i] = 0; out_data[i] = 0;
T w = sample_weight == nullptr ? 1. : sample_weight_data[i]; T w = sample_weight == nullptr ? 1. : sample_weight_data[i];
// for true classes for (int64_t j = 0; j < sampled_labels_num; ++j) {
for (; j < num_true_class; ++j) { int64_t target = sample_labels_data[i * sampled_labels_num + j];
T o = sample_out_data[i * sample_out->dims()[1] + j]; T o = sample_out_data[i * sampled_labels_num + j];
T cost = -log(o / (o + b)); float b = sampler->Probability(target) * num_neg_samples;
out_data[i] += w * cost; T cost = (j < num_true_class) ? -log(o / (o + b)) : -log(b / (o + b));
}
// for sampled neg classes
for (; j < sample_labels->dims()[1]; ++j) {
T o = sample_out_data[i * sample_out->dims()[1] + j];
T cost = -log(b / (o + b));
out_data[i] += w * cost; out_data[i] += w * cost;
} }
} }
delete sampler;
} }
}; };
...@@ -158,20 +177,45 @@ class NCEGradKernel : public framework::OpKernel<T> { ...@@ -158,20 +177,45 @@ class NCEGradKernel : public framework::OpKernel<T> {
if (label != nullptr) { if (label != nullptr) {
num_true_class = label->dims()[1]; num_true_class = label->dims()[1];
} }
T b = 1. / num_total_classes * num_neg_samples;
int sampler_type = context.Attr<int>("sampler");
int seed = context.Attr<int>("seed");
Sampler* sampler;
switch (sampler_type) {
case 0: {
sampler = new math::UniformSampler(num_total_classes - 1, seed);
break;
}
case 1: {
sampler = new math::LogUniformSampler(num_total_classes - 1, seed);
break;
}
case 2: {
auto custom_dist = context.Input<Tensor>("CustomDistribution");
const float* custom_dist_data = custom_dist->data<float>();
PADDLE_ENFORCE_EQ(custom_dist->numel(), num_total_classes);
sampler = new math::CustomSampler(num_total_classes - 1,
custom_dist_data, seed);
break;
}
default: { PADDLE_THROW("Unsupported SamplerType."); }
}
// T b = 1. / num_total_classes * num_neg_samples;
Tensor sample_grad; // tmp tensor Tensor sample_grad; // tmp tensor
T* sample_grad_data = T* sample_grad_data =
sample_grad.mutable_data<T>(sample_labels->dims(), context.GetPlace()); sample_grad.mutable_data<T>(sample_labels->dims(), context.GetPlace());
// backward cost // backward cost
for (int64_t i = 0; i < sample_labels->numel(); ++i) { for (int64_t i = 0; i < sample_labels->numel(); ++i) {
int64_t label_idx = i % sample_labels->dims()[1];
int64_t sample_idx = i / sample_labels->dims()[1];
float b = sampler->Probability(sample_labels_data[i]) * num_neg_samples;
T o = sample_out_data[i]; T o = sample_out_data[i];
T w = sample_weight == nullptr T w = sample_weight == nullptr ? 1 : sample_weight_data[sample_idx];
? 1 sample_grad_data[i] = label_idx < num_true_class
: sample_weight_data[i / sample_labels->dims()[1]];
sample_grad_data[i] = (i % sample_labels->dims()[1]) < num_true_class
? w * (b / (o + b)) * (o - 1) ? w * (b / (o + b)) * (o - 1)
: w * (o * (1 - o) / (o + b)); : w * (o * (1 - o) / (o + b));
sample_grad_data[i] *= d_out_data[i / sample_labels->dims()[1]]; sample_grad_data[i] *= d_out_data[sample_idx];
} }
// get d_bias // get d_bias
auto d_bias = context.Output<Tensor>(framework::GradVarName("Bias")); auto d_bias = context.Output<Tensor>(framework::GradVarName("Bias"));
...@@ -207,6 +251,7 @@ class NCEGradKernel : public framework::OpKernel<T> { ...@@ -207,6 +251,7 @@ class NCEGradKernel : public framework::OpKernel<T> {
w_matrix.chip(sample_labels_data[i], 0) * sample_grad_data[i]; w_matrix.chip(sample_labels_data[i], 0) * sample_grad_data[i];
} }
} }
delete sampler;
} }
}; };
} // namespace operators } // namespace operators
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/mixed_vector.h"
#include "paddle/fluid/operators/math/softmax.h"
#include "paddle/fluid/operators/warpctc_op.h"
#include "paddle/fluid/platform/cudnn_helper.h"
namespace paddle {
namespace operators {
#if CUDNN_VERSION >= 7001
using ScopedTensorDescriptor = platform::ScopedTensorDescriptor;
using ScopedCTCLossDescriptor = platform::ScopedCTCLossDescriptor;
using DataLayout = platform::DataLayout;
template <typename DeviceContext, typename T>
class CudnnCTCKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
// =====================Copied code from warpctc===========================
auto* logits = ctx.Input<LoDTensor>("Logits");
auto* label = ctx.Input<LoDTensor>("Label");
auto* warpctc_grad = ctx.Output<LoDTensor>("WarpCTCGrad");
auto* loss = ctx.Output<LoDTensor>("Loss");
const size_t level = 0;
auto logits_lod = framework::ToAbsOffset(logits->lod());
auto logits_dims = logits->dims();
PADDLE_ENFORCE_EQ(logits_dims[0],
static_cast<int64_t>(logits_lod[level].back()),
"The first dimension of Input(Logits) should be equal to "
"the sum of all sequences' lengths.");
auto label_lod = framework::ToAbsOffset(label->lod());
auto label_dims = label->dims();
PADDLE_ENFORCE_EQ(
label_dims[0], label->numel(),
"The width of each timestep in Input(Label) should be 1.");
const size_t num_sequences = logits_lod[level].size() - 1;
PADDLE_ENFORCE_EQ(num_sequences, label_lod[level].size() - 1,
"The number of sequences of Input(Logits) should be "
"equal to that of Input(Label).");
PADDLE_ENFORCE_LE(num_sequences, 256,
"The labelLengths must less than 256 for cudnn call.");
const size_t sequence_width = logits->numel() / logits_dims[0];
auto loss_dims =
framework::make_ddim({static_cast<int64_t>(num_sequences), 1});
// NOTE: cudnn takes softmax input, calculate softmax first, then do padding
auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
LoDTensor softmax_logits;
softmax_logits.mutable_data<T>(logits->dims(), ctx.GetPlace());
softmax_logits.set_lod(logits_lod);
int rank = logits->dims().size();
Tensor in_2d = framework::ReshapeToMatrix(*logits, rank - 1);
Tensor out_2d = framework::ReshapeToMatrix(softmax_logits, rank - 1);
math::SoftmaxFunctor<DeviceContext, T, false>()(dev_ctx, &in_2d, &out_2d);
// ctc needs sequences data stored in transposed padding format
// logits and grad using padding data of layout 'TNC'
// T: max_sequence_length
// N: batch_size (num_sequences)
// C: width
LoDTensor warpctc_logits;
const size_t max_sequence_length =
math::MaximumSequenceLength(logits_lod[level]);
auto warpctc_logits_dims =
framework::make_ddim({static_cast<int64_t>(max_sequence_length),
static_cast<int64_t>(num_sequences),
static_cast<int64_t>(sequence_width)});
warpctc_logits.mutable_data<T>(warpctc_logits_dims, ctx.GetPlace());
LoDTensor cpu_pad_value;
T* pad_value_data =
cpu_pad_value.mutable_data<T>({1}, platform::CPUPlace());
*pad_value_data = static_cast<T>(0);
LoDTensor pad_value;
if (platform::is_cpu_place(ctx.GetPlace())) {
pad_value = cpu_pad_value;
} else {
TensorCopySync(cpu_pad_value, ctx.GetPlace(), &pad_value);
}
math::PaddingLoDTensorFunctor<DeviceContext, T>()(
ctx.template device_context<DeviceContext>(), softmax_logits,
&warpctc_logits, pad_value, -1, 0, false /* norm_by_times */,
math::kLengthBatchWidth);
const T* warpctc_logits_data = warpctc_logits.data<T>();
std::vector<int> warpctc_label_lengths(num_sequences);
std::vector<int> warpctc_logits_lengths(num_sequences);
for (size_t i = 0; i < num_sequences; ++i) {
warpctc_label_lengths[i] = label_lod[level][i + 1] - label_lod[level][i];
warpctc_logits_lengths[i] =
logits_lod[level][i + 1] - logits_lod[level][i];
}
T* warpctc_grad_data =
warpctc_grad->mutable_data<T>(warpctc_logits.dims(), ctx.GetPlace());
math::SetConstant<DeviceContext, T>()(
ctx.template device_context<DeviceContext>(), warpctc_grad,
static_cast<T>(0));
Tensor warpctc_label;
TensorCopySync(*label, platform::CPUPlace(), &warpctc_label);
const int* warpctc_label_data = warpctc_label.data<int>();
// ========================================================================
ScopedTensorDescriptor logits_desc;
ScopedTensorDescriptor grad_desc;
ScopedCTCLossDescriptor ctcloss_desc;
// layout here doesn't have effect.
DataLayout layout = DataLayout::kNCHW;
auto cu_logits_desc = logits_desc.descriptor<T>(
layout, framework::vectorize2int(warpctc_logits.dims()));
auto cu_grad_desc = grad_desc.descriptor<T>(
layout, framework::vectorize2int(warpctc_grad->dims()));
auto cu_ctcloss_desc = ctcloss_desc.descriptor<T>();
auto handle = dev_ctx.cudnn_handle();
size_t workspace_size;
CUDNN_ENFORCE(platform::dynload::cudnnGetCTCLossWorkspaceSize(
handle, cu_logits_desc, cu_grad_desc, warpctc_label_data,
warpctc_label_lengths.data(), warpctc_logits_lengths.data(),
CUDNN_CTC_LOSS_ALGO_DETERMINISTIC, cu_ctcloss_desc, &workspace_size));
T* loss_data = loss->mutable_data<T>(loss_dims, ctx.GetPlace());
auto workspace_handle = dev_ctx.cudnn_workspace_handle();
auto cudnn_func = [&](void* cudnn_workspace) {
CUDNN_ENFORCE(platform::dynload::cudnnCTCLoss(
handle, cu_logits_desc, warpctc_logits_data, warpctc_label_data,
warpctc_label_lengths.data(), warpctc_logits_lengths.data(),
loss_data, cu_grad_desc, warpctc_grad_data,
CUDNN_CTC_LOSS_ALGO_DETERMINISTIC, cu_ctcloss_desc, cudnn_workspace,
workspace_size));
};
workspace_handle.RunFunc(cudnn_func, workspace_size);
}
};
template <typename DeviceContext, typename T>
class CudnnCTCGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* warpctc_grad = ctx.Input<LoDTensor>("WarpCTCGrad");
auto* logits_grad = ctx.Output<LoDTensor>(framework::GradVarName("Logits"));
const Tensor* loss_grad = ctx.Input<Tensor>(framework::GradVarName("Loss"));
logits_grad->mutable_data<T>(ctx.GetPlace());
bool norm_by_times = ctx.Attr<bool>("norm_by_times");
math::UnpaddingLoDTensorFunctor<DeviceContext, T>()(
ctx.template device_context<DeviceContext>(), *warpctc_grad,
logits_grad, -1, 0, norm_by_times, math::kLengthBatchWidth);
const T* loss_grad_data = loss_grad->data<T>();
math::ScaleLoDTensorFunctor<DeviceContext, T>()(
ctx.template device_context<DeviceContext>(), loss_grad_data,
logits_grad);
}
};
#endif
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
#if CUDNN_VERSION >= 7001
REGISTER_OP_KERNEL(
warpctc, CUDNN, plat::CUDAPlace,
ops::CudnnCTCKernel<paddle::platform::CUDADeviceContext, float>);
REGISTER_OP_KERNEL(
warpctc_grad, CUDNN, plat::CUDAPlace,
ops::CudnnCTCGradKernel<paddle::platform::CUDADeviceContext, float>);
#endif
...@@ -14,6 +14,10 @@ limitations under the License. */ ...@@ -14,6 +14,10 @@ limitations under the License. */
#include "paddle/fluid/operators/warpctc_op.h" #include "paddle/fluid/operators/warpctc_op.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/cudnn_helper.h"
#endif
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -45,9 +49,16 @@ class WarpCTCOp : public framework::OperatorWithKernel { ...@@ -45,9 +49,16 @@ class WarpCTCOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
framework::LibraryType library_{framework::LibraryType::kPlain};
#ifdef PADDLE_WITH_CUDA
if (platform::CanCUDNNBeUsed(ctx)) {
library_ = framework::LibraryType::kCUDNN;
}
#endif
framework::DataLayout layout_ = framework::DataLayout::kAnyLayout;
return framework::OpKernelType( return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("Logits")->type()), framework::ToDataType(ctx.Input<Tensor>("Logits")->type()),
ctx.device_context()); ctx.device_context(), layout_, library_);
} }
}; };
...@@ -86,6 +97,10 @@ class WarpCTCOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -86,6 +97,10 @@ class WarpCTCOpMaker : public framework::OpProtoAndCheckerMaker {
"normalize the gradients by the number of time-step, " "normalize the gradients by the number of time-step, "
"which is also the sequence's length.") "which is also the sequence's length.")
.SetDefault(false); .SetDefault(false);
AddAttr<bool>("use_cudnn",
"(bool, default: false), whether to "
"use cudnn kernel.")
.SetDefault(false);
AddComment(R"DOC( AddComment(R"DOC(
An operator integrating the open-source An operator integrating the open-source
[warp-ctc](https://github.com/baidu-research/warp-ctc) library, which is used in [warp-ctc](https://github.com/baidu-research/warp-ctc) library, which is used in
......
...@@ -380,5 +380,28 @@ inline bool CanCUDNNBeUsed(const framework::ExecutionContext& ctx) { ...@@ -380,5 +380,28 @@ inline bool CanCUDNNBeUsed(const framework::ExecutionContext& ctx) {
return use_cudnn; return use_cudnn;
} }
#if CUDNN_VERSION >= 7001
class ScopedCTCLossDescriptor {
public:
ScopedCTCLossDescriptor() {
PADDLE_ENFORCE(dynload::cudnnCreateCTCLossDescriptor(&desc_));
}
~ScopedCTCLossDescriptor() {
PADDLE_ENFORCE(dynload::cudnnDestroyCTCLossDescriptor(desc_));
}
template <typename T>
inline cudnnCTCLossDescriptor_t descriptor() {
PADDLE_ENFORCE(
dynload::cudnnSetCTCLossDescriptor(desc_, CudnnDataType<T>::type));
return desc_;
}
private:
cudnnCTCLossDescriptor_t desc_;
DISABLE_COPY_AND_ASSIGN(ScopedCTCLossDescriptor);
};
#endif
} // namespace platform } // namespace platform
} // namespace paddle } // namespace paddle
...@@ -154,7 +154,13 @@ CUDNN_DNN_ROUTINE_EACH_R5(DECLARE_DYNAMIC_LOAD_CUDNN_WRAP) ...@@ -154,7 +154,13 @@ CUDNN_DNN_ROUTINE_EACH_R5(DECLARE_DYNAMIC_LOAD_CUDNN_WRAP)
#if CUDNN_VERSION >= 7001 #if CUDNN_VERSION >= 7001
#define CUDNN_DNN_ROUTINE_EACH_R7(__macro) \ #define CUDNN_DNN_ROUTINE_EACH_R7(__macro) \
__macro(cudnnSetConvolutionGroupCount); \ __macro(cudnnSetConvolutionGroupCount); \
__macro(cudnnSetConvolutionMathType); __macro(cudnnSetConvolutionMathType); \
__macro(cudnnCreateCTCLossDescriptor); \
__macro(cudnnDestroyCTCLossDescriptor); \
__macro(cudnnGetCTCLossDescriptor); \
__macro(cudnnSetCTCLossDescriptor); \
__macro(cudnnGetCTCLossWorkspaceSize); \
__macro(cudnnCTCLoss);
CUDNN_DNN_ROUTINE_EACH_R7(DECLARE_DYNAMIC_LOAD_CUDNN_WRAP) CUDNN_DNN_ROUTINE_EACH_R7(DECLARE_DYNAMIC_LOAD_CUDNN_WRAP)
#endif #endif
......
...@@ -404,8 +404,8 @@ def open_recordio_file(filename, ...@@ -404,8 +404,8 @@ def open_recordio_file(filename,
startup_var.desc.set_dtypes(dtypes) startup_var.desc.set_dtypes(dtypes)
startup_var.persistable = True startup_var.persistable = True
main_prog_var = _copy_reader_var_( main_prog_var = _copy_reader_var_(default_main_program().current_block(),
default_main_program().current_block(), startup_var) startup_var)
if pass_num > 1: if pass_num > 1:
main_prog_var = multi_pass(reader=main_prog_var, pass_num=pass_num) main_prog_var = multi_pass(reader=main_prog_var, pass_num=pass_num)
......
...@@ -4193,7 +4193,7 @@ def ctc_greedy_decoder(input, blank, name=None): ...@@ -4193,7 +4193,7 @@ def ctc_greedy_decoder(input, blank, name=None):
return ctc_out return ctc_out
def warpctc(input, label, blank=0, norm_by_times=False): def warpctc(input, label, blank=0, norm_by_times=False, use_cudnn=False):
""" """
An operator integrating the open source Warp-CTC library An operator integrating the open source Warp-CTC library
(https://github.com/baidu-research/warp-ctc) (https://github.com/baidu-research/warp-ctc)
...@@ -4218,6 +4218,7 @@ def warpctc(input, label, blank=0, norm_by_times=False): ...@@ -4218,6 +4218,7 @@ def warpctc(input, label, blank=0, norm_by_times=False):
by the number of time-step, which is also the sequence's length. by the number of time-step, which is also the sequence's length.
There is no need to normalize the gradients if warpctc layer was There is no need to normalize the gradients if warpctc layer was
follewed by a mean_op. follewed by a mean_op.
use_cudnn (bool, default false): Whether to use cudnn.
Returns: Returns:
Variable: The Connectionist Temporal Classification (CTC) loss, Variable: The Connectionist Temporal Classification (CTC) loss,
...@@ -4241,8 +4242,11 @@ def warpctc(input, label, blank=0, norm_by_times=False): ...@@ -4241,8 +4242,11 @@ def warpctc(input, label, blank=0, norm_by_times=False):
'Label': [label]}, 'Label': [label]},
outputs={'WarpCTCGrad': [grad_out], outputs={'WarpCTCGrad': [grad_out],
'Loss': [loss_out]}, 'Loss': [loss_out]},
attrs={'blank': blank, attrs={
'norm_by_times': norm_by_times}) 'blank': blank,
'norm_by_times': norm_by_times,
'use_cudnn': use_cudnn
})
return loss_out return loss_out
...@@ -4315,7 +4319,10 @@ def nce(input, ...@@ -4315,7 +4319,10 @@ def nce(input,
param_attr=None, param_attr=None,
bias_attr=None, bias_attr=None,
num_neg_samples=None, num_neg_samples=None,
name=None): name=None,
sampler="uniform",
custom_dist=None,
seed=0):
""" """
${comment} ${comment}
...@@ -4338,6 +4345,14 @@ def nce(input, ...@@ -4338,6 +4345,14 @@ def nce(input,
num_neg_samples (int): ${num_neg_samples_comment} num_neg_samples (int): ${num_neg_samples_comment}
name (str|None): A name for this layer(optional). If set None, the layer name (str|None): A name for this layer(optional). If set None, the layer
will be named automatically. Default: None. will be named automatically. Default: None.
sampler (str): The sampler used to sample class from negtive classes.
It can be 'uniform', 'log_uniform' or 'custom_dist'.
default: 'uniform'.
custom_dist (Variable): A tensor with shape [num_total_classes].
It is used when sampler is set to 'custom_dist'.
custom_dist[i] is the probsbility of i-th class to be sampled.
default: None.
seed (int): The seed used in sampler. default: 0.
Returns: Returns:
Variable: The output nce loss. Variable: The output nce loss.
...@@ -4367,6 +4382,16 @@ def nce(input, ...@@ -4367,6 +4382,16 @@ def nce(input,
loss = layers.nce(input=embs, label=words[label_word], loss = layers.nce(input=embs, label=words[label_word],
num_total_classes=dict_size, param_attr='nce.w', num_total_classes=dict_size, param_attr='nce.w',
bias_attr='nce.b') bias_attr='nce.b')
#or use custom distribution
dist = fluid.layers.assign(input=np.array([0.05,0.5,0.1,0.3,0.05]).astype("float32"))
loss = layers.nce(input=embs, label=words[label_word],
num_total_classes=5, param_attr='nce.w',
bias_attr='nce.b',
num_neg_samples=3,
sampler="custom_dist",
custom_dist=dist)
""" """
helper = LayerHelper('nce', **locals()) helper = LayerHelper('nce', **locals())
assert isinstance(input, Variable) assert isinstance(input, Variable)
...@@ -4401,9 +4426,31 @@ def nce(input, ...@@ -4401,9 +4426,31 @@ def nce(input,
else: else:
num_neg_samples = int(num_neg_samples) num_neg_samples = int(num_neg_samples)
inputs = {
'Input': input,
'Label': label,
'Weight': w,
'Bias': b,
'SampleWeight': sample_weight if sample_weight is not None else []
}
if sampler == "uniform":
sampler = 0
elif sampler == "log_uniform":
sampler = 1
elif sampler == "custom_dist":
assert custom_dist is not None
assert isinstance(custom_dist, Variable)
inputs['CustomDistribution'] = custom_dist
sampler = 2
else:
raise Exception("Unsupported sampler type.")
attrs = { attrs = {
'num_total_classes': int(num_total_classes), 'num_total_classes': int(num_total_classes),
'num_neg_samples': num_neg_samples 'num_neg_samples': num_neg_samples,
'seed': seed,
'sampler': sampler
} }
helper.append_op( helper.append_op(
......
...@@ -83,6 +83,34 @@ class TestInferShape(unittest.TestCase): ...@@ -83,6 +83,34 @@ class TestInferShape(unittest.TestCase):
mul_op_desc.infer_shape(block) mul_op_desc.infer_shape(block)
self.assertEqual(out.shape(), [x_shape[0], y_shape[1]]) self.assertEqual(out.shape(), [x_shape[0], y_shape[1]])
def test_expand_op(self):
prog = core.ProgramDesc()
self.assertIsNotNone(prog)
block = prog.block(0)
self.assertIsNotNone(block)
shape = [-1, 20]
expand_times = [3, 1]
# prepare input/output
x1 = block.var(six.b("x"))
x1.set_type(core.VarDesc.VarType.LOD_TENSOR)
x1.set_shape(shape)
out = block.var(six.b("out"))
out.set_type(core.VarDesc.VarType.LOD_TENSOR)
# prepare the operator
sum_op_desc = block.append_op()
sum_op_desc.set_type("expand")
sum_op_desc.set_input("X", ["x"])
sum_op_desc.set_output("Out", ["out"])
sum_op_desc._set_attr('expand_times', expand_times)
sum_op_desc.check_attrs()
sum_op_desc.infer_shape(block)
self.assertEqual(out.shape(), shape)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -68,7 +68,9 @@ class TestNCE(OpTest): ...@@ -68,7 +68,9 @@ class TestNCE(OpTest):
self.attrs = { self.attrs = {
'num_total_classes': num_classes, 'num_total_classes': num_classes,
'num_neg_samples': num_neg_samples, 'num_neg_samples': num_neg_samples,
'custom_neg_classes': list(range(num_neg_samples)) 'custom_neg_classes': list(range(num_neg_samples)),
'seed': 0,
'sampler': 0
} }
self.inputs = { self.inputs = {
'Input': input, 'Input': input,
......
...@@ -183,6 +183,7 @@ class TestWarpCTCOp(OpTest): ...@@ -183,6 +183,7 @@ class TestWarpCTCOp(OpTest):
self.labels_lod = [[3, 1, 4, 4]] self.labels_lod = [[3, 1, 4, 4]]
self.blank = self.num_classes - 1 self.blank = self.num_classes - 1
self.norm_by_times = False self.norm_by_times = False
self.use_cudnn = False
def setUp(self): def setUp(self):
self.op_type = "warpctc" self.op_type = "warpctc"
...@@ -215,7 +216,11 @@ class TestWarpCTCOp(OpTest): ...@@ -215,7 +216,11 @@ class TestWarpCTCOp(OpTest):
"Label": (labels, self.labels_lod) "Label": (labels, self.labels_lod)
} }
self.outputs = {"Loss": loss} self.outputs = {"Loss": loss}
self.attrs = {"blank": self.blank, "norm_by_times": self.norm_by_times} self.attrs = {
"blank": self.blank,
"norm_by_times": self.norm_by_times,
"use_cudnn": self.use_cudnn
}
def test_check_output(self): def test_check_output(self):
self.check_output() self.check_output()
...@@ -233,6 +238,22 @@ class TestWarpCTCOpCase1(TestWarpCTCOp): ...@@ -233,6 +238,22 @@ class TestWarpCTCOpCase1(TestWarpCTCOp):
self.labels_lod = [[3, 1, 4, 4]] self.labels_lod = [[3, 1, 4, 4]]
self.blank = 0 self.blank = 0
self.norm_by_times = False self.norm_by_times = False
self.use_cudnn = False
class TestCudnnCTCOp(TestWarpCTCOp):
def config(self):
self.batch_size = 4
self.num_classes = 8
self.logits_lod = [[4, 1, 3, 3]]
self.labels_lod = [[3, 1, 4, 4]]
self.blank = 0
self.norm_by_times = False
self.use_cudnn = True
def test_check_grad(self):
self.outputs['WarpCTCGrad'] = self.gradient
self.check_grad(["Logits"], "Loss", max_relative_error=0.01)
if __name__ == "__main__": if __name__ == "__main__":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册