未验证 提交 128adf53 编写于 作者: D dzhwinter 提交者: GitHub

[Speed]implement cudnn sequence softmax cudnn (#8978)

* "add softmax cudnn functor support"

* "add testing"

* "refine cmakelist"

* "sequence softmax forward speed up"

* "add softmax grad"

* "fix sequence softmax test"

* "add double precision'

* "fix softmax test"

* "add softmax cudnn support"

* "fix softmax cudnn test"

* "add softmax to nn.py"

* "fix compile bug"

* "refine cmakelist"

* "fix ci"

* "fix based on comment"

* "fix based on comments"

* "fix ci"
上级 9b9f3f09
develop 2.0.1-rocm-post Ligoml-patch-1 OliverLPH-patch-1 OliverLPH-patch-2 PaddlePM-patch-1 PaddlePM-patch-2 ZHUI-patch-1 add_default_att add_model_benchmark_ci add_some_yaml_config addfile all_new_design_exec ascendrc ascendrelease cherry_undefined_var compile_windows delete_2.0.1-rocm-post delete_add_default_att delete_all_new_design_exec delete_ascendrc delete_compile_windows delete_delete_addfile delete_disable_iterable_dataset_unittest delete_fix_dataloader_memory_leak delete_fix_imperative_dygraph_error delete_fix_retry_ci delete_fix_undefined_var delete_improve_sccache delete_incubate/lite delete_paddle_tiny_install delete_paralleltest delete_prv-disable-more-cache delete_revert-31068-fix_conv3d_windows delete_revert-31562-mean delete_revert-33630-bug-fix delete_revert-34159-add_npu_bce_logical_dev delete_revert-34910-spinlocks_for_allocator delete_revert-35069-revert-34910-spinlocks_for_allocator delete_revert-36057-dev/read_flags_in_ut dingjiaweiww-patch-1 disable_iterable_dataset_unittest dy2static enable_eager_model_test final_state_gen_python_c final_state_intermediate fix-numpy-issue fix_concat_slice fix_dataloader_memory_leak fix_imperative_dygraph_error fix_npu_ci fix_op_flops fix_retry_ci fix_rnn_docs fix_tensor_type fix_undefined_var fixiscan fixiscan1 fixiscan2 fixiscan3 github/fork/123malin/netifaces github/fork/123malin/tdm_abacus github/fork/AshburnLee/dev_unique github/fork/ForFishes/fix_memory_matmul github/fork/ForFishes/rm_fluid github/fork/LielinJiang/move-2.0-api github/fork/LielinJiang/visual-dl-cb github/fork/LiuChiachi/add-transformer-generate-square-subsequent-mask-api github/fork/LiuChiachi/fix-example-code-for-hapi-Model github/fork/LiuChiachi/remove-input-requirment-in-dygraph-Model github/fork/MrChengmo/fix_ps_profiler github/fork/MrChengmo/update_ps_heter github/fork/PWhiddy/patch-1 github/fork/Shixiaowei02/dev/save_load_upgrade github/fork/TCChenlong/fix_hapi github/fork/TCChenlong/fix_inden github/fork/Thunderbrook/xpu_slice github/fork/XieYunshen/disable_ut_test_parallel_executor_fetch_isolated_var github/fork/XieYunshen/disable_ut_test_parallel_executor_fetch_isolated_var_2 github/fork/XieYunshen/disable_ut_test_parallel_executor_fetch_isolated_var_3 github/fork/XieYunshen/timeout_20S_ut github/fork/ZeyuChen/remove-nltk github/fork/arlesniak/arlesniak/selective__mkldnn_flags github/fork/baiyfbupt/code_doc_mig github/fork/chalsliu/set_timeout github/fork/chen-zhiyu/develop github/fork/chenwhql/ci/try_to_find_test_buffer_shared_memory_reuse_pass_error github/fork/chenwhql/dygraph/remove_scale_loss_and_apply_collective_grads github/fork/chenwhql/saveload/add_get_inference_program github/fork/chenwhql/saveload/remove_save_load_config github/fork/cryoco/pass-compatibility-trt github/fork/danleifeng/isempty_api2.0 github/fork/frankwhzhang/api_transfer github/fork/hbwx24/error_msg/cuda_kernel_error_msg github/fork/heavengate/cherry_yolo_box github/fork/heavengate/update_yolo_box github/fork/iclementine/rnn_fix github/fork/iducn/testestse github/fork/jczaja/prv-25537-fix github/fork/jeff41404/release/1.8 github/fork/jiweibo/api_2.0 github/fork/jiweibo/fix_lite_resnet50_test github/fork/juncaipeng/fix_doc_1 github/fork/lfchener/sample_code github/fork/littletomatodonkey/fix_reg_doc github/fork/liym27/dy2stat_update_assign_to_rc20 github/fork/luotao1/profiler_ut github/fork/mapingshuo/add_wait github/fork/mapingshuo/doc_2.0 github/fork/mapingshuo/zero-0.5 github/fork/miraiwk/dev github/fork/pangyoki/add-Categorical-class-branch github/fork/pangyoki/add-multinomial-op-branch github/fork/pangyoki/fix-test_distritbution-CI github/fork/qjing666/doublegrad github/fork/qjing666/fix_hdfs_download github/fork/sandyhouse/add_gather_etc github/fork/sandyhouse/add_send_recv_alltoall_etc github/fork/sandyhouse/pipeline_exe_run github/fork/seiriosPlus/feature/large_scale_kv_save_delta github/fork/seiriosPlus/fix/paddle_errors_fix github/fork/seiriosPlus/fix/paddle_op_errors github/fork/shangzhizhou/fix_test_activation_op_random_bug github/fork/smallv0221/yxp0924 github/fork/smallv0221/yxp0925 github/fork/swtkiwi/del-matplotlib github/fork/tianshuo78520a/kunlun_test github/fork/tianshuo78520a/update_dockerfile github/fork/wanghaoshuang/bert_fuse github/fork/wanghaoshuang/label_smooth github/fork/wanghuancoder/develop_CUDASynchronize github/fork/wanghuancoder/develop_Layer_doc github/fork/wanghuancoder/develop_ParameterList_doc github/fork/wanghuancoder/develop_Sequential_doc github/fork/wanghuancoder/develop_bilinear_tensor_product github/fork/wanghuancoder/develop_coverage_build_sh github/fork/wanghuancoder/develop_in_dynamic_mode_doc github/fork/wanghuancoder/develop_unique_name_doc github/fork/wangxicoding/fleet_meta_combine github/fork/wawltor/error_message_fix_5 github/fork/willthefrog/remove_l2_norm github/fork/windstamp/momentum_op github/fork/windstamp/mv_op_5 github/fork/windstamp/normal_api github/fork/wojtuss/wojtuss/fusion_gru_quantization github/fork/wojtuss/wojtuss/quantization-with-shift github/fork/wzzju/fix_err_info github/fork/wzzju/pure_fp16 github/fork/xiemoyuan/op_error_message github/fork/xiemoyuan/optimize_error_message github/fork/yaoxuefeng6/fix_doc github/fork/yaoxuefeng6/mod_dataset_v2 github/fork/yongqiangma/lod github/fork/ysh329/fix-clip-by-norm-error github/fork/ysh329/fix-error-clip-by-value github/fork/yukavio/error_info github/fork/zhangting2020/conv_filter_grad github/fork/zhangting2020/is_compile_with_cuda github/fork/zhangting2020/place_doc github/fork/zhangting2020/program github/fork/zhhsplendid/fix_any github/fork/zhhsplendid/refine_api2 github/fork/zhhsplendid/refine_api2_test github/fork/zhhsplendid/refine_api_test_ptb_lm github/fork/zhhsplendid/refine_api_test_resnet github/fork/zhhsplendid/refine_api_test_simnet github/fork/zhiqiu/dev/refine_initializer github/fork/zhiqiu/dev/remove_inplace_argument github/fork/zlsh80826/nvinfer_plugin_var_len_cuda11 improve_sccache incubate/infrt incubate/lite inplace_addto make_flag_adding_easier master move_embedding_to_phi move_histogram_to_pten move_sgd_to_phi move_slice_to_pten move_temporal_shift_to_phi move_yolo_box_to_phi npu_fix_alloc numel paddle_tiny_install paralleltest preln_ernie prv-disable-more-cache prv-md-even-more prv-onednn-2.5 pten_tensor_refactor release/0.12.0 release/0.13.0 release/0.14.0 release/0.15.0 release/1.0.0 release/1.1 release/1.2 release/1.3 release/1.4 release/1.5 release/1.6 release/1.7 release/1.8 release/2.0 release/2.0-alpha release/2.0-beta release/2.0-rc release/2.0-rc1 release/2.1 release/2.2 release/2.3 release/2.3-fc-ernie-fix release/2.4 release/lite-0.1 revert-24981-add_device_attr_for_regulization revert-26856-strategy_example2 revert-27520-disable_pr revert-31068-fix_conv3d_windows revert-31562-mean revert-32290-develop-hardlabel revert-33037-forci revert-33475-fix_cifar_label_dimension revert-33630-bug-fix revert-34159-add_npu_bce_logical_dev revert-34406-add_copy_from_tensor revert-34910-spinlocks_for_allocator revert-35069-revert-34910-spinlocks_for_allocator revert-36057-dev/read_flags_in_ut revert-36201-refine_fast_threaded_ssa_graph_executor revert-36985-add_license revert-37318-refactor_dygraph_to_eager revert-37926-eager_coreops_500 revert-37956-revert-37727-pylayer_support_tuple revert-38100-mingdong revert-38301-allocation_rearrange_pr revert-38703-numpy_bf16_package_reupload revert-38732-remove_useless_header_in_elementwise_mul_grad revert-38959-Reduce_Grad revert-39143-adjust_empty revert-39227-move_trace_op_to_pten revert-39268-dev/remove_concat_fluid_kernel revert-40170-support_partial_grad revert-41056-revert-40727-move_some_activaion_to_phi revert-41065-revert-40993-mv_ele_floordiv_pow revert-41068-revert-40790-phi_new revert-41944-smaller_inference_api_test revert-42149-do-not-reset-default-stream-for-stream-safe-cuda-allocator revert-43155-fix_ut_tempfile revert-43882-revert-41944-smaller_inference_api_test revert-45808-phi/simplify_size_op revert-46827-deform_comment rocm_dev_0217 support_weight_transpose test_benchmark_ci test_feature_precision_test_c test_model_benchmark test_model_benchmark_ci zhiqiu-patch-1 v2.4.0-rc0 v2.3.2 v2.3.1 v2.3.0 v2.3.0-rc0 v2.2.2 v2.2.1 v2.2.0 v2.2.0-rc0 v2.2.0-bak0 v2.1.3 v2.1.2 v2.1.1 v2.1.0 v2.1.0-rc0 v2.0.2 v2.0.1 v2.0.0 v2.0.0-rc1 v2.0.0-rc0 v2.0.0-beta0 v2.0.0-alpha0 v1.8.5 v1.8.4 v1.8.3 v1.8.2 v1.8.1 v1.8.0 v1.7.2 v1.7.1 v1.7.0 v1.6.3 v1.6.2 v1.6.1 v1.6.0 v1.6.0-rc0 v1.5.2 v1.5.1 v1.5.0 v1.4.1 v1.4.0 v1.3.2 v1.3.1 v1.3.0 v1.2.1 v1.2.0 v1.1.0 v1.0.2 v1.0.1 v1.0.0 v1.0.0-rc0 v0.15.0 v0.15.0-rc0 v0.14.0 v0.13.0 v0.12.0 lite-v0.1
无相关合并请求
......@@ -14,13 +14,86 @@ limitations under the License. */
#define EIGEN_USE_GPU
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/softmax.h"
#include "paddle/fluid/operators/math/softmax_impl.h"
#include "paddle/fluid/platform/cudnn_helper.h"
namespace paddle {
namespace operators {
namespace math {
using Tensor = framework::Tensor;
using ScopedTensorDescriptor = platform::ScopedTensorDescriptor;
using DataLayout = platform::DataLayout;
template <typename T>
using CudnnDataType = platform::CudnnDataType<T>;
template <typename T>
void SoftmaxCUDNNFunctor<T>::operator()(
const platform::CUDADeviceContext& context, const framework::Tensor* X,
framework::Tensor* Y) {
// ------------------- cudnn descriptors ---------------------
ScopedTensorDescriptor xDesc;
ScopedTensorDescriptor yDesc;
std::vector<int> cudnn_tensor_dims = framework::vectorize2int(X->dims());
DataLayout layout = DataLayout::kNCHW;
if (cudnn_tensor_dims.size() == 5) {
layout = DataLayout::kNCDHW;
}
// NOTE(*) : cudnn softmax only support >= 4D Tensor,
// fill 1 at unused dims
if (cudnn_tensor_dims.size() <= 2) {
cudnn_tensor_dims.resize(4, 1);
}
cudnnTensorDescriptor_t cudnn_x_desc =
xDesc.descriptor<T>(layout, cudnn_tensor_dims);
cudnnTensorDescriptor_t cudnn_y_desc =
xDesc.descriptor<T>(layout, cudnn_tensor_dims);
PADDLE_ENFORCE(platform::dynload::cudnnSoftmaxForward(
context.cudnn_handle(), CUDNN_SOFTMAX_ACCURATE,
CUDNN_SOFTMAX_MODE_INSTANCE, CudnnDataType<T>::kOne(), cudnn_x_desc,
X->data<T>(), CudnnDataType<T>::kZero(), cudnn_y_desc,
Y->mutable_data<T>(context.GetPlace())));
}
template <typename T>
void SoftmaxGradCUDNNFunctor<T>::operator()(
const platform::CUDADeviceContext& context, const framework::Tensor* Y,
const framework::Tensor* YGrad, framework::Tensor* XGrad) {
// ------------------- cudnn descriptors ---------------------
ScopedTensorDescriptor yDesc;
ScopedTensorDescriptor dyDesc;
ScopedTensorDescriptor dxDesc;
std::vector<int> cudnn_tensor_dims = framework::vectorize2int(Y->dims());
DataLayout layout = DataLayout::kNCHW;
if (cudnn_tensor_dims.size() == 5) {
layout = DataLayout::kNCDHW;
}
// NOTE(*) : cudnn softmax only support >= 4D Tensor,
// fill 1 at unused dims
if (cudnn_tensor_dims.size() <= 2) {
cudnn_tensor_dims.resize(4, 1);
}
cudnnTensorDescriptor_t cudnn_y_desc =
yDesc.descriptor<T>(layout, cudnn_tensor_dims);
cudnnTensorDescriptor_t cudnn_xgrad_desc =
dxDesc.descriptor<T>(layout, cudnn_tensor_dims);
cudnnTensorDescriptor_t cudnn_ygrad_desc =
dyDesc.descriptor<T>(layout, cudnn_tensor_dims);
PADDLE_ENFORCE(platform::dynload::cudnnSoftmaxBackward(
context.cudnn_handle(), CUDNN_SOFTMAX_ACCURATE,
CUDNN_SOFTMAX_MODE_INSTANCE, CudnnDataType<T>::kOne(), cudnn_y_desc,
Y->data<T>(), cudnn_ygrad_desc, YGrad->data<T>(),
CudnnDataType<T>::kZero(), cudnn_xgrad_desc,
XGrad->mutable_data<T>(context.GetPlace())));
}
template class SoftmaxCUDNNFunctor<float>;
template class SoftmaxCUDNNFunctor<double>;
template class SoftmaxGradCUDNNFunctor<float>;
template class SoftmaxGradCUDNNFunctor<double>;
template class SoftmaxFunctor<platform::CUDADeviceContext, float>;
template class SoftmaxFunctor<platform::CUDADeviceContext, double>;
template class SoftmaxGradFunctor<platform::CUDADeviceContext, float>;
......
......@@ -33,6 +33,23 @@ class SoftmaxGradFunctor {
const framework::Tensor* y_grad, framework::Tensor* x_grad);
};
#ifdef PADDLE_WITH_CUDA
template <typename T>
class SoftmaxCUDNNFunctor {
public:
void operator()(const platform::CUDADeviceContext& context,
const framework::Tensor* X, framework::Tensor* Y);
};
template <typename T>
class SoftmaxGradCUDNNFunctor {
public:
void operator()(const platform::CUDADeviceContext& context,
const framework::Tensor* Y, const framework::Tensor* y_grad,
framework::Tensor* x_grad);
};
#endif
} // namespace math
} // namespace operators
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/softmax.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
template <typename T>
class SequenceSoftmaxCUDNNKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* x = ctx.Input<LoDTensor>("X");
auto* out = ctx.Output<LoDTensor>("Out");
auto lod = x->lod();
auto dims = x->dims();
const size_t level = lod.size() - 1;
PADDLE_ENFORCE_EQ(dims[0], static_cast<int64_t>(lod[level].back()),
"The first dimension of Input(X) should be equal to the "
"sum of all sequences' lengths.");
PADDLE_ENFORCE_EQ(dims[0], x->numel(),
"The width of each timestep in Input(X) of "
"SequenceSoftmaxOp should be 1.");
out->mutable_data<T>(ctx.GetPlace());
for (int i = 0; i < static_cast<int>(lod[level].size()) - 1; ++i) {
int start_pos = static_cast<int>(lod[level][i]);
int end_pos = static_cast<int>(lod[level][i + 1]);
Tensor x_i = x->Slice(start_pos, end_pos);
Tensor out_i = out->Slice(start_pos, end_pos);
// Reshape from (end_pos - start_pos) x 1UL to 1UL x (end_pos - start_pos)
framework::DDim dims_i =
// framework::make_ddim({1UL, end_pos - start_pos, 1UL, 1UL});
framework::make_ddim({1UL, end_pos - start_pos});
x_i.Resize(dims_i);
out_i.Resize(dims_i);
math::SoftmaxCUDNNFunctor<T>()(
ctx.template device_context<platform::CUDADeviceContext>(), &x_i,
&out_i);
}
}
};
template <typename T>
class SequenceSoftmaxGradCUDNNKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* out = ctx.Input<LoDTensor>("Out");
auto* out_grad = ctx.Input<LoDTensor>(framework::GradVarName("Out"));
auto* x = ctx.Input<LoDTensor>("X");
auto* x_grad = ctx.Output<LoDTensor>(framework::GradVarName("X"));
auto lod = x->lod();
const size_t level = lod.size() - 1;
x_grad->mutable_data<T>(ctx.GetPlace());
for (int i = 0; i < static_cast<int>(lod[level].size()) - 1; ++i) {
int start_pos = static_cast<int>(lod[level][i]);
int end_pos = static_cast<int>(lod[level][i + 1]);
Tensor out_i = out->Slice(start_pos, end_pos);
Tensor out_grad_i = out_grad->Slice(start_pos, end_pos);
Tensor x_grad_i = x_grad->Slice(start_pos, end_pos);
// Reshape from (end_pos - start_pos) x 1UL to 1UL x (end_pos - start_pos)
framework::DDim dims_i = framework::make_ddim({1UL, end_pos - start_pos});
out_i.Resize(dims_i);
out_grad_i.Resize(dims_i);
x_grad_i.Resize(dims_i);
math::SoftmaxGradCUDNNFunctor<T>()(
ctx.template device_context<platform::CUDADeviceContext>(), &out_i,
&out_grad_i, &x_grad_i);
}
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_KERNEL(sequence_softmax, CUDNN, ::paddle::platform::CUDAPlace,
ops::SequenceSoftmaxCUDNNKernel<float>,
ops::SequenceSoftmaxCUDNNKernel<double>)
REGISTER_OP_KERNEL(sequence_softmax_grad, CUDNN, ::paddle::platform::CUDAPlace,
ops::SequenceSoftmaxGradCUDNNKernel<float>,
ops::SequenceSoftmaxGradCUDNNKernel<double>)
......@@ -29,6 +29,29 @@ class SequenceSoftmaxOp : public framework::OperatorWithKernel {
ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
ctx->ShareLoD("X", /*->*/ "Out");
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
// choose cudnn kernel if the runtime supported.
bool use_cudnn = ctx.Attr<bool>("use_cudnn");
bool runtime_cudnn_support = false;
#ifdef PADDLE_WITH_CUDA
if (platform::is_gpu_place(ctx.GetPlace())) {
auto& dev_ctx =
ctx.template device_context<platform::CUDADeviceContext>();
runtime_cudnn_support = dev_ctx.cudnn_handle() != nullptr ? true : false;
}
#endif
framework::LibraryType library_ = framework::LibraryType::kPlain;
if (use_cudnn && runtime_cudnn_support) {
library_ = framework::LibraryType::kCUDNN;
}
std::string data_format = ctx.Attr<std::string>("data_format");
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("X")->type()), ctx.GetPlace(),
framework::StringToDataLayout(data_format), library_);
}
};
class SequenceSoftmaxOpMaker : public framework::OpProtoAndCheckerMaker {
......@@ -41,6 +64,17 @@ class SequenceSoftmaxOpMaker : public framework::OpProtoAndCheckerMaker {
AddOutput("Out",
"(LoDTensor) 1-D or 2-D output LoDTensor with the 2-nd dimension "
"of length 1.");
AddAttr<bool>(
"use_cudnn",
"(bool, default false) Only used in cudnn kernel, need install cudnn")
.SetDefault(false);
AddAttr<std::string>(
"data_format",
"(string, default NCHW) Only used in "
"An optional string from: \"NHWC\", \"NCHW\". "
"Defaults to \"NHWC\". Specify the data format of the output data, "
"the input will be transformed automatically. ")
.SetDefault("AnyLayout");
AddComment(R"DOC(
Sequence Softmax Operator.
......@@ -91,6 +125,29 @@ class SequenceSoftmaxGradOp : public framework::OperatorWithKernel {
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
// choose cudnn kernel if the runtime supported.
bool use_cudnn = ctx.Attr<bool>("use_cudnn");
bool runtime_cudnn_support = false;
#ifdef PADDLE_WITH_CUDA
if (platform::is_gpu_place(ctx.GetPlace())) {
auto& dev_ctx =
ctx.template device_context<platform::CUDADeviceContext>();
runtime_cudnn_support = dev_ctx.cudnn_handle() != nullptr ? true : false;
}
#endif
framework::LibraryType library_ = framework::LibraryType::kPlain;
if (use_cudnn && runtime_cudnn_support) {
library_ = framework::LibraryType::kCUDNN;
}
std::string data_format = ctx.Attr<std::string>("data_format");
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("X")->type()), ctx.GetPlace(),
framework::StringToDataLayout(data_format), library_);
}
};
} // namespace operators
......@@ -102,7 +159,9 @@ REGISTER_OP(sequence_softmax, ops::SequenceSoftmaxOp,
ops::SequenceSoftmaxGradOp);
REGISTER_OP_CPU_KERNEL(
sequence_softmax,
ops::SequenceSoftmaxKernel<paddle::platform::CPUDeviceContext, float>);
ops::SequenceSoftmaxKernel<paddle::platform::CPUDeviceContext, float>,
ops::SequenceSoftmaxKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(
sequence_softmax_grad,
ops::SequenceSoftmaxGradKernel<paddle::platform::CPUDeviceContext, float>);
ops::SequenceSoftmaxGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::SequenceSoftmaxGradKernel<paddle::platform::CPUDeviceContext, double>);
......@@ -17,7 +17,10 @@ limitations under the License. */
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
sequence_softmax,
ops::SequenceSoftmaxKernel<paddle::platform::CUDADeviceContext, float>)
ops::SequenceSoftmaxKernel<paddle::platform::CUDADeviceContext, float>,
ops::SequenceSoftmaxKernel<paddle::platform::CUDADeviceContext, double>)
REGISTER_OP_CUDA_KERNEL(
sequence_softmax_grad,
ops::SequenceSoftmaxGradKernel<paddle::platform::CUDADeviceContext, float>);
ops::SequenceSoftmaxGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::SequenceSoftmaxGradKernel<paddle::platform::CUDADeviceContext,
double>);
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/math/softmax.h"
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename T>
class SoftmaxCUDNNKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* X = context.Input<Tensor>("X");
auto* Out = context.Output<Tensor>("Out");
// allocate memory on device.
Out->mutable_data<T>(context.GetPlace());
math::SoftmaxCUDNNFunctor<T>()(
context.template device_context<platform::CUDADeviceContext>(), X, Out);
}
};
template <typename T>
class SoftmaxGradCUDNNKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* Out = context.Input<Tensor>("Out");
auto* dOut = context.Input<Tensor>(framework::GradVarName("Out"));
auto* dX = context.Output<Tensor>(framework::GradVarName("X"));
// allocate memory on device.
dX->mutable_data<T>(context.GetPlace());
math::SoftmaxGradCUDNNFunctor<T>()(
context.template device_context<platform::CUDADeviceContext>(), Out,
dOut, dX);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_KERNEL(softmax, CUDNN, ::paddle::platform::CUDAPlace,
ops::SoftmaxCUDNNKernel<float>);
REGISTER_OP_KERNEL(softmax_grad, CUDNN, ::paddle::platform::CUDAPlace,
ops::SoftmaxGradCUDNNKernel<float>);
......@@ -33,6 +33,29 @@ class SoftmaxOp : public framework::OperatorWithKernel {
ctx->SetOutputDim("Out", x_dims);
ctx->ShareLoD("X", /*->*/ "Out");
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
// choose cudnn kernel if the runtime supported.
bool use_cudnn = ctx.Attr<bool>("use_cudnn");
bool runtime_cudnn_support = false;
#ifdef PADDLE_WITH_CUDA
if (platform::is_gpu_place(ctx.GetPlace())) {
auto& dev_ctx =
ctx.template device_context<platform::CUDADeviceContext>();
runtime_cudnn_support = dev_ctx.cudnn_handle() != nullptr ? true : false;
}
#endif
framework::LibraryType library_ = framework::LibraryType::kPlain;
if (use_cudnn && runtime_cudnn_support) {
library_ = framework::LibraryType::kCUDNN;
}
std::string data_format = ctx.Attr<std::string>("data_format");
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("X")->type()), ctx.GetPlace(),
framework::StringToDataLayout(data_format), library_);
}
};
class SoftmaxOpMaker : public framework::OpProtoAndCheckerMaker {
......@@ -43,6 +66,17 @@ class SoftmaxOpMaker : public framework::OpProtoAndCheckerMaker {
"The input tensor of softmax. "
"2-D with shape [batch_size, input_feature_dimensions].");
AddOutput("Out", "The normalized values with the same shape as X.");
AddAttr<bool>(
"use_cudnn",
"(bool, default false) Only used in cudnn kernel, need install cudnn")
.SetDefault(false);
AddAttr<std::string>(
"data_format",
"(string, default NCHW) Only used in "
"An optional string from: \"NHWC\", \"NCHW\". "
"Defaults to \"NHWC\". Specify the data format of the output data, "
"the input will be transformed automatically. ")
.SetDefault("AnyLayout");
AddComment(R"DOC(
Softmax Operator.
......@@ -80,6 +114,29 @@ class SoftmaxOpGrad : public framework::OperatorWithKernel {
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
// choose cudnn kernel if the runtime supported.
bool use_cudnn = ctx.Attr<bool>("use_cudnn");
bool runtime_cudnn_support = false;
#ifdef PADDLE_WITH_CUDA
if (platform::is_gpu_place(ctx.GetPlace())) {
auto& dev_ctx =
ctx.template device_context<platform::CUDADeviceContext>();
runtime_cudnn_support = dev_ctx.cudnn_handle() != nullptr ? true : false;
}
#endif
framework::LibraryType library_ = framework::LibraryType::kPlain;
if (use_cudnn && runtime_cudnn_support) {
library_ = framework::LibraryType::kCUDNN;
}
std::string data_format = ctx.Attr<std::string>("data_format");
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("X")->type()), ctx.GetPlace(),
framework::StringToDataLayout(data_format), library_);
}
};
} // namespace operators
......
......@@ -289,7 +289,7 @@ inline bool CanCUDNNBeUsed(const framework::ExecutionContext& ctx) {
use_cudnn &= paddle::platform::is_gpu_place(ctx.GetPlace());
#ifdef PADDLE_WITH_CUDA
if (use_cudnn) {
auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
auto& dev_ctx = ctx.device_context<platform::CUDADeviceContext>();
use_cudnn &= dev_ctx.cudnn_handle() != nullptr;
}
#endif
......
......@@ -132,7 +132,7 @@ def detection_output(loc,
old_shape = scores.shape
scores = ops.reshape(x=scores, shape=(-1, old_shape[-1]))
scores = ops.softmax(x=scores)
scores = nn.softmax(input=scores)
scores = ops.reshape(x=scores, shape=old_shape)
scores = nn.transpose(scores, perm=[0, 2, 1])
......
......@@ -39,6 +39,8 @@ __all__ = [
'sequence_conv',
'conv2d',
'sequence_pool',
'sequence_softmax',
'softmax',
'pool2d',
'batch_norm',
'beam_search_decode',
......@@ -1085,6 +1087,30 @@ def sequence_conv(input,
return helper.append_activation(pre_act)
def sequence_softmax(input, param_attr=None, bias_attr=None, use_cudnn=True):
helper = LayerHelper('sequence_softmax', **locals())
dtype = helper.input_dtype()
softmax_out = helper.create_tmp_variable(dtype)
helper.append_op(
type="sequence_softmax",
inputs={"X": input},
outputs={"Out": softmax_out},
attrs={"use_cudnn": use_cudnn})
return softmax_out
def softmax(input, param_attr=None, bias_attr=None, use_cudnn=True):
helper = LayerHelper('softmax', **locals())
dtype = helper.input_dtype()
softmax_out = helper.create_tmp_variable(dtype)
helper.append_op(
type="softmax",
inputs={"X": input},
outputs={"Out": softmax_out},
attrs={"use_cudnn": use_cudnn})
return softmax_out
def conv2d(input,
num_filters,
filter_size,
......
......@@ -45,13 +45,30 @@ __activations__ = [
]
__all__ = [
'mean', 'mul', 'reshape', 'scale', 'sigmoid_cross_entropy_with_logits',
'elementwise_add', 'elementwise_div', 'elementwise_sub', 'elementwise_mul',
'elementwise_max', 'elementwise_min', 'elementwise_pow', 'clip',
'clip_by_norm', 'softmax', 'sequence_softmax', 'logical_and', 'logical_or',
'logical_xor', 'logical_not', 'uniform_random',
'uniform_random_batch_size_like', 'gaussian_random',
'gaussian_random_batch_size_like', 'cumsum', 'scatter'
'mean',
'mul',
'reshape',
'scale',
'sigmoid_cross_entropy_with_logits',
'elementwise_add',
'elementwise_div',
'elementwise_sub',
'elementwise_mul',
'elementwise_max',
'elementwise_min',
'elementwise_pow',
'clip',
'clip_by_norm',
'logical_and',
'logical_or',
'logical_xor',
'logical_not',
'uniform_random',
'uniform_random_batch_size_like',
'gaussian_random',
'gaussian_random_batch_size_like',
'cumsum',
'scatter',
] + __activations__
for _OP in set(__all__):
......
......@@ -220,7 +220,7 @@ class TestBook(unittest.TestCase):
seq_data = layers.data(
name='seq_data', shape=[10, 10], dtype='float32', lod_level=1)
seq = layers.fc(input=seq_data, size=20)
self.assertIsNotNone(layers.sequence_softmax(x=seq))
self.assertIsNotNone(layers.sequence_softmax(seq))
print(str(program))
def test_softmax(self):
......@@ -228,7 +228,7 @@ class TestBook(unittest.TestCase):
with program_guard(program):
data = layers.data(name='data', shape=[10], dtype='float32')
hid = layers.fc(input=data, size=20)
self.assertIsNotNone(layers.softmax(x=hid))
self.assertIsNotNone(layers.softmax(hid))
print(str(program))
def test_get_places(self):
......
......@@ -16,11 +16,15 @@ import unittest
import numpy as np
from op_test import OpTest
from test_softmax_op import stable_softmax
import paddle.fluid.core as core
class TestSequenceSoftmaxOp(OpTest):
def setUp(self):
self.op_type = "sequence_softmax"
self.use_cudnn = False
self.init_op_type()
x = np.random.uniform(0.1, 1, (11, 1)).astype("float32")
lod = [[0, 4, 5, 8, 11]]
......@@ -34,12 +38,31 @@ class TestSequenceSoftmaxOp(OpTest):
self.inputs = {"X": (x, lod)}
self.outputs = {"Out": out}
self.attrs = {'use_cudnn': self.use_cudnn, }
def init_op_type(self):
pass
def test_check_output(self):
self.check_output()
if self.use_cudnn:
place = core.CUDAPlace(0)
self.check_output_with_place(place, atol=1e-5)
else:
self.check_output()
def test_check_grad(self):
self.check_grad(["X"], "Out", max_relative_error=0.01)
if self.use_cudnn:
place = core.CUDAPlace(0)
self.check_grad_with_place(
place, ["X"], "Out", max_relative_error=0.01)
else:
self.check_grad(["X"], "Out", max_relative_error=0.01)
# ----------------cudnn Sequencesoftmax----------------
class TestSequenceSoftmaxCUDNNOp(TestSequenceSoftmaxOp):
def init_op_type(self):
self.use_cudnn = True
if __name__ == "__main__":
......
......@@ -15,6 +15,7 @@
import unittest
import numpy as np
from op_test import OpTest
import paddle.fluid.core as core
def stable_softmax(x):
......@@ -27,18 +28,37 @@ def stable_softmax(x):
class TestSoftmaxOp(OpTest):
def setUp(self):
self.op_type = "softmax"
self.use_cudnn = False
self.inputs = {
'X': np.random.uniform(0.1, 1, [10, 10]).astype("float32")
}
self.outputs = {
'Out': np.apply_along_axis(stable_softmax, 1, self.inputs['X'])
}
self.attrs = {'use_cudnn': self.use_cudnn, }
def init_op_type(self):
pass
def test_check_output(self):
self.check_output()
if self.use_cudnn:
place = core.CUDAPlace(0)
self.check_output_with_place(place, atol=1e-5)
else:
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Out')
if self.use_cudnn:
place = core.CUDAPlace(0)
self.check_grad_with_place(
place, ["X"], "Out", max_relative_error=0.01)
else:
self.check_grad(["X"], "Out", max_relative_error=0.01)
class TestSoftmaxCUDNNOp(TestSoftmaxOp):
def init_op_type(self):
self.use_cudnn = True
if __name__ == "__main__":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册
反馈
建议
客服 返回
顶部