Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
6757a315
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
未验证
提交
6757a315
编写于
9月 18, 2018
作者:
C
chengduo
提交者:
GitHub
9月 18, 2018
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[Accelerate] Refine seq_softmax_op (#13421)
* refine seq_softmax_op * fix seq_softmax * use cub in seq_softmax
上级
c6865954
变更
7
隐藏空白更改
内联
并排
Showing
7 changed file
with
261 addition
and
75 deletion
+261
-75
paddle/fluid/API.spec
paddle/fluid/API.spec
+1
-1
paddle/fluid/operators/CMakeLists.txt
paddle/fluid/operators/CMakeLists.txt
+1
-1
paddle/fluid/operators/sequence_softmax_cudnn_op.cu.cc
paddle/fluid/operators/sequence_softmax_cudnn_op.cu.cc
+3
-3
paddle/fluid/operators/sequence_softmax_op.cu
paddle/fluid/operators/sequence_softmax_op.cu
+171
-0
paddle/fluid/operators/sequence_softmax_op.cu.cc
paddle/fluid/operators/sequence_softmax_op.cu.cc
+0
-26
paddle/fluid/operators/sequence_softmax_op.h
paddle/fluid/operators/sequence_softmax_op.h
+83
-42
python/paddle/fluid/layers/nn.py
python/paddle/fluid/layers/nn.py
+2
-2
未找到文件。
paddle/fluid/API.spec
浏览文件 @
6757a315
...
...
@@ -107,7 +107,7 @@ paddle.fluid.layers.sequence_conv ArgSpec(args=['input', 'num_filters', 'filter_
paddle.fluid.layers.conv2d ArgSpec(args=['input', 'num_filters', 'filter_size', 'stride', 'padding', 'dilation', 'groups', 'param_attr', 'bias_attr', 'use_cudnn', 'use_mkldnn', 'act', 'name'], varargs=None, keywords=None, defaults=(1, 0, 1, None, None, None, True, False, None, None))
paddle.fluid.layers.conv3d ArgSpec(args=['input', 'num_filters', 'filter_size', 'stride', 'padding', 'dilation', 'groups', 'param_attr', 'bias_attr', 'use_cudnn', 'use_mkldnn', 'act', 'name'], varargs=None, keywords=None, defaults=(1, 0, 1, None, None, None, True, False, None, None))
paddle.fluid.layers.sequence_pool ArgSpec(args=['input', 'pool_type'], varargs=None, keywords=None, defaults=None)
paddle.fluid.layers.sequence_softmax ArgSpec(args=['input', 'param_attr', 'bias_attr', 'use_cudnn'], varargs=None, keywords=None, defaults=(None, None,
Tru
e))
paddle.fluid.layers.sequence_softmax ArgSpec(args=['input', 'param_attr', 'bias_attr', 'use_cudnn'], varargs=None, keywords=None, defaults=(None, None,
Fals
e))
paddle.fluid.layers.softmax ArgSpec(args=['input', 'param_attr', 'bias_attr', 'use_cudnn', 'name'], varargs=None, keywords=None, defaults=(None, None, True, None))
paddle.fluid.layers.pool2d ArgSpec(args=['input', 'pool_size', 'pool_type', 'pool_stride', 'pool_padding', 'global_pooling', 'use_cudnn', 'ceil_mode', 'use_mkldnn', 'name'], varargs=None, keywords=None, defaults=(-1, 'max', 1, 0, False, True, False, False, None))
paddle.fluid.layers.pool3d ArgSpec(args=['input', 'pool_size', 'pool_type', 'pool_stride', 'pool_padding', 'global_pooling', 'use_cudnn', 'ceil_mode', 'use_mkldnn', 'name'], varargs=None, keywords=None, defaults=(-1, 'max', 1, 0, False, True, False, False, None))
...
...
paddle/fluid/operators/CMakeLists.txt
浏览文件 @
6757a315
...
...
@@ -252,12 +252,12 @@ endif()
op_library
(
cross_entropy_op DEPS cross_entropy
)
if
(
WITH_GPU
)
op_library
(
softmax_with_cross_entropy_op DEPS cross_entropy softmax cub
)
op_library
(
sequence_softmax_op DEPS cub
)
else
()
op_library
(
softmax_with_cross_entropy_op DEPS cross_entropy softmax
)
endif
()
op_library
(
softmax_op DEPS softmax
)
op_library
(
sequence_softmax_op DEPS softmax
)
if
(
WITH_GPU AND TENSORRT_FOUND
)
op_library
(
tensorrt_engine_op DEPS tensorrt_engine tensorrt_converter
)
file
(
APPEND
${
pybind_file
}
"USE_CUDA_ONLY_OP(tensorrt_engine);
\n
"
)
...
...
paddle/fluid/operators/sequence_softmax_cudnn_op.cu.cc
浏览文件 @
6757a315
...
...
@@ -29,8 +29,8 @@ class SequenceSoftmaxCUDNNKernel : public framework::OpKernel<T> {
auto
*
x
=
ctx
.
Input
<
LoDTensor
>
(
"X"
);
auto
*
out
=
ctx
.
Output
<
LoDTensor
>
(
"Out"
);
auto
lod
=
x
->
lod
();
auto
dims
=
x
->
dims
();
auto
&
lod
=
x
->
lod
();
auto
&
dims
=
x
->
dims
();
const
size_t
level
=
lod
.
size
()
-
1
;
PADDLE_ENFORCE_EQ
(
dims
[
0
],
static_cast
<
int64_t
>
(
lod
[
level
].
back
()),
...
...
@@ -71,7 +71,7 @@ class SequenceSoftmaxGradCUDNNKernel : public framework::OpKernel<T> {
if
(
x_grad
)
{
x_grad
->
set_lod
(
x
->
lod
());
}
auto
lod
=
x
->
lod
();
auto
&
lod
=
x
->
lod
();
const
size_t
level
=
lod
.
size
()
-
1
;
x_grad
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
...
...
paddle/fluid/operators/sequence_softmax_op.cu
0 → 100644
浏览文件 @
6757a315
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <algorithm>
#include <cub/cub.cuh> // NOLINT
#include "paddle/fluid/operators/sequence_softmax_op.h"
namespace
paddle
{
namespace
operators
{
using
LoDTensor
=
framework
::
LoDTensor
;
__device__
__forceinline__
float
real_exp
(
float
x
)
{
return
expf
(
x
);
}
__device__
__forceinline__
double
real_exp
(
double
x
)
{
return
exp
(
x
);
}
template
<
typename
T
,
int
BlockDim
>
using
BlockReduce
=
cub
::
BlockReduce
<
T
,
BlockDim
>
;
template
<
typename
T
,
int
BlockDim
>
using
BlockReduceTempStorage
=
typename
BlockReduce
<
T
,
BlockDim
>::
TempStorage
;
template
<
typename
T
,
int
BlockDim
>
__global__
void
sequence_softmax_kernel
(
const
T
*
in_data
,
const
size_t
*
ref_lod
,
const
size_t
src_hight
,
T
*
out_data
)
{
__shared__
BlockReduceTempStorage
<
T
,
BlockDim
>
temp_storage
;
__shared__
T
shared_max_data
;
__shared__
T
shared_sum_data
;
for
(
int
i
=
blockIdx
.
x
;
i
<
src_hight
;
i
+=
gridDim
.
x
)
{
size_t
start
=
ref_lod
[
i
];
size_t
span
=
ref_lod
[
i
+
1
]
-
start
;
// Find the max ele
T
max_ele
=
-
FLT_MAX
;
for
(
int
tid
=
threadIdx
.
x
;
tid
<
span
;
tid
+=
blockDim
.
x
)
{
T
ele
=
in_data
[
start
+
tid
];
max_ele
=
max_ele
>
ele
?
max_ele
:
ele
;
}
max_ele
=
BlockReduce
<
T
,
BlockDim
>
(
temp_storage
).
Reduce
(
max_ele
,
cub
::
Max
());
if
(
threadIdx
.
x
==
0
)
{
shared_max_data
=
max_ele
;
}
__syncthreads
();
// sum
T
sum_data
=
0
;
for
(
int
tid
=
threadIdx
.
x
;
tid
<
span
;
tid
+=
blockDim
.
x
)
{
T
ele
=
in_data
[
start
+
tid
];
sum_data
+=
real_exp
(
ele
-
shared_max_data
);
}
sum_data
=
BlockReduce
<
T
,
BlockDim
>
(
temp_storage
).
Reduce
(
sum_data
,
cub
::
Sum
());
if
(
threadIdx
.
x
==
0
)
{
shared_sum_data
=
sum_data
;
}
__syncthreads
();
// get final resit
for
(
int
tid
=
threadIdx
.
x
;
tid
<
span
;
tid
+=
blockDim
.
x
)
{
T
ele
=
in_data
[
start
+
tid
];
ele
=
real_exp
(
ele
-
shared_max_data
)
/
shared_sum_data
;
out_data
[
start
+
tid
]
=
ele
;
}
}
}
template
<
typename
T
,
int
BlockDim
>
__global__
void
sequence_softmax_grad_kernel
(
const
T
*
softmax_grad_data
,
const
T
*
softmax_data
,
const
size_t
*
ref_lod
,
const
size_t
src_hight
,
T
*
dx_data
)
{
__shared__
BlockReduceTempStorage
<
T
,
BlockDim
>
temp_storage
;
__shared__
T
shared_data
;
for
(
int
i
=
blockIdx
.
x
;
i
<
src_hight
;
i
+=
gridDim
.
x
)
{
size_t
start
=
ref_lod
[
i
];
size_t
span
=
ref_lod
[
i
+
1
]
-
start
;
T
result
=
0
;
for
(
int
tid
=
threadIdx
.
x
;
tid
<
span
;
tid
+=
blockDim
.
x
)
{
size_t
idx
=
start
+
tid
;
T
s_g_d
=
softmax_grad_data
[
idx
];
T
s_d
=
softmax_data
[
idx
];
result
+=
s_g_d
*
s_d
;
}
result
=
BlockReduce
<
T
,
BlockDim
>
(
temp_storage
).
Reduce
(
result
,
cub
::
Sum
());
if
(
threadIdx
.
x
==
0
)
{
shared_data
=
result
;
}
__syncthreads
();
for
(
int
tid
=
threadIdx
.
x
;
tid
<
span
;
tid
+=
blockDim
.
x
)
{
size_t
idx
=
start
+
tid
;
T
s_g_d
=
softmax_grad_data
[
idx
];
T
s_d
=
softmax_data
[
idx
];
dx_data
[
idx
]
=
(
s_g_d
-
shared_data
)
*
s_d
;
}
}
}
template
<
typename
T
>
struct
SequenceSoftmaxFunctor
<
platform
::
CUDADeviceContext
,
T
>
{
void
operator
()(
const
platform
::
CUDADeviceContext
&
context
,
const
LoDTensor
&
x
,
const
framework
::
Vector
<
size_t
>
&
ref_lod
,
/*referenced lod*/
LoDTensor
*
out
)
{
int
hight
=
ref_lod
.
size
()
-
1
;
const
int
kThreadsPerBlock
=
32
;
int
thread_x
=
kThreadsPerBlock
;
int
max_threads
=
context
.
GetMaxPhysicalThreadCount
();
int
max_blocks
=
std
::
max
(
max_threads
/
kThreadsPerBlock
,
1
);
dim3
block_size
(
thread_x
);
dim3
grid_size
(
max_blocks
);
sequence_softmax_kernel
<
T
,
kThreadsPerBlock
><<<
grid_size
,
block_size
,
0
,
context
.
stream
()
>>>
(
x
.
data
<
T
>
(),
ref_lod
.
CUDAData
(
context
.
GetPlace
()),
hight
,
out
->
mutable_data
<
T
>
(
context
.
GetPlace
()));
}
};
template
<
typename
T
>
struct
SequenceSoftmaxGradFunctor
<
platform
::
CUDADeviceContext
,
T
>
{
void
operator
()(
const
platform
::
CUDADeviceContext
&
context
,
const
LoDTensor
&
dout
,
const
LoDTensor
&
out
,
const
framework
::
Vector
<
size_t
>
&
ref_lod
,
/*referenced lod*/
LoDTensor
*
dx
)
{
size_t
hight
=
ref_lod
.
size
()
-
1
;
const
int
kThreadsPerBlock
=
32
;
int
thread_x
=
kThreadsPerBlock
;
int
max_threads
=
context
.
GetMaxPhysicalThreadCount
();
int
max_blocks
=
std
::
max
(
max_threads
/
kThreadsPerBlock
,
1
);
dim3
block_size
(
thread_x
);
dim3
grid_size
(
max_blocks
);
sequence_softmax_grad_kernel
<
T
,
kThreadsPerBlock
><<<
grid_size
,
block_size
,
0
,
context
.
stream
()
>>>
(
dout
.
data
<
T
>
(),
out
.
data
<
T
>
(),
ref_lod
.
CUDAData
(
context
.
GetPlace
()),
hight
,
dx
->
mutable_data
<
T
>
(
context
.
GetPlace
()));
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OP_CUDA_KERNEL
(
sequence_softmax
,
ops
::
SequenceSoftmaxKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
,
ops
::
SequenceSoftmaxKernel
<
paddle
::
platform
::
CUDADeviceContext
,
double
>
);
REGISTER_OP_CUDA_KERNEL
(
sequence_softmax_grad
,
ops
::
SequenceSoftmaxGradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
,
ops
::
SequenceSoftmaxGradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
double
>
);
paddle/fluid/operators/sequence_softmax_op.cu.cc
已删除
100644 → 0
浏览文件 @
c6865954
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/sequence_softmax_op.h"
namespace
ops
=
paddle
::
operators
;
REGISTER_OP_CUDA_KERNEL
(
sequence_softmax
,
ops
::
SequenceSoftmaxKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
,
ops
::
SequenceSoftmaxKernel
<
paddle
::
platform
::
CUDADeviceContext
,
double
>
);
REGISTER_OP_CUDA_KERNEL
(
sequence_softmax_grad
,
ops
::
SequenceSoftmaxGradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
,
ops
::
SequenceSoftmaxGradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
double
>
);
paddle/fluid/operators/sequence_softmax_op.h
浏览文件 @
6757a315
...
...
@@ -15,7 +15,6 @@ limitations under the License. */
#pragma once
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/softmax.h"
namespace
paddle
{
namespace
operators
{
...
...
@@ -23,12 +22,76 @@ namespace operators {
using
Tensor
=
framework
::
Tensor
;
using
LoDTensor
=
framework
::
LoDTensor
;
template
<
typename
DeviceContext
,
typename
T
>
struct
SequenceSoftmaxFunctor
{
void
operator
()(
const
DeviceContext
&
ctx
,
const
LoDTensor
&
x
,
const
framework
::
Vector
<
size_t
>
&
ref_lod
,
/*expand referenced lod*/
LoDTensor
*
out
);
};
template
<
typename
DeviceContext
,
typename
T
>
struct
SequenceSoftmaxGradFunctor
{
void
operator
()(
const
DeviceContext
&
ctx
,
const
LoDTensor
&
dout
,
const
LoDTensor
&
out
,
const
framework
::
Vector
<
size_t
>
&
ref_lod
,
/*referenced lod*/
LoDTensor
*
dx
);
};
template
<
typename
T
>
struct
SequenceSoftmaxFunctor
<
platform
::
CPUDeviceContext
,
T
>
{
void
operator
()(
const
platform
::
CPUDeviceContext
&
ctx
,
const
LoDTensor
&
x
,
const
framework
::
Vector
<
size_t
>
&
ref_lod
,
/*referenced lod*/
LoDTensor
*
out
)
{
size_t
hight
=
ref_lod
.
size
()
-
1
;
const
T
*
in_data
=
x
.
data
<
T
>
();
T
*
out_data
=
out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
for
(
size_t
i
=
0
;
i
<
hight
;
++
i
)
{
size_t
span
=
ref_lod
[
i
+
1
]
-
ref_lod
[
i
];
T
result
=
0
;
for
(
size_t
j
=
0
;
j
<
span
;
++
j
)
{
result
+=
exp
(
in_data
[
ref_lod
[
i
]
+
j
]);
}
for
(
size_t
j
=
0
;
j
<
span
;
++
j
)
{
out_data
[
ref_lod
[
i
]
+
j
]
=
exp
(
in_data
[
ref_lod
[
i
]
+
j
])
/
result
;
}
}
}
};
template
<
typename
T
>
struct
SequenceSoftmaxGradFunctor
<
platform
::
CPUDeviceContext
,
T
>
{
void
operator
()(
const
platform
::
CPUDeviceContext
&
ctx
,
const
LoDTensor
&
dout
,
const
LoDTensor
&
out
,
const
framework
::
Vector
<
size_t
>
&
ref_lod
,
/*referenced lod*/
LoDTensor
*
dx
)
{
size_t
hight
=
ref_lod
.
size
()
-
1
;
const
T
*
softmax_grad_data
=
dout
.
data
<
T
>
();
const
T
*
softmax
=
out
.
data
<
T
>
();
T
*
dx_data
=
dx
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
for
(
size_t
i
=
0
;
i
<
hight
;
++
i
)
{
size_t
span
=
ref_lod
[
i
+
1
]
-
ref_lod
[
i
];
T
result
=
0
;
for
(
size_t
j
=
0
;
j
<
span
;
++
j
)
{
result
+=
softmax_grad_data
[
ref_lod
[
i
]
+
j
]
*
softmax
[
ref_lod
[
i
]
+
j
];
}
for
(
size_t
j
=
0
;
j
<
span
;
++
j
)
{
dx_data
[
ref_lod
[
i
]
+
j
]
=
(
softmax_grad_data
[
ref_lod
[
i
]
+
j
]
-
result
)
*
softmax
[
ref_lod
[
i
]
+
j
];
}
}
}
};
template
<
typename
DeviceContext
,
typename
T
>
class
SequenceSoftmaxKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
*
x
=
ctx
.
Input
<
LoDTensor
>
(
"X"
);
auto
*
out
=
ctx
.
Output
<
LoDTensor
>
(
"Out"
);
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
*
x
=
ctx
.
Input
<
LoDTensor
>
(
"X"
);
auto
*
out
=
ctx
.
Output
<
LoDTensor
>
(
"Out"
);
auto
lod
=
x
->
lod
();
auto
dims
=
x
->
dims
();
...
...
@@ -42,55 +105,33 @@ class SequenceSoftmaxKernel : public framework::OpKernel<T> {
"SequenceSoftmaxOp should be 1."
);
out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
for
(
int
i
=
0
;
i
<
static_cast
<
int
>
(
lod
[
level
].
size
())
-
1
;
++
i
)
{
int
start_pos
=
static_cast
<
int
>
(
lod
[
level
][
i
]);
int
end_pos
=
static_cast
<
int
>
(
lod
[
level
][
i
+
1
]);
Tensor
x_i
=
x
->
Slice
(
start_pos
,
end_pos
);
Tensor
out_i
=
out
->
Slice
(
start_pos
,
end_pos
);
// Reshape from (end_pos - start_pos) x 1UL to 1UL x (end_pos - start_pos)
framework
::
DDim
dims_i
=
framework
::
make_ddim
({
1UL
,
end_pos
-
start_pos
});
x_i
.
Resize
(
dims_i
);
out_i
.
Resize
(
dims_i
);
math
::
SoftmaxFunctor
<
DeviceContext
,
T
>
()(
ctx
.
template
device_context
<
DeviceContext
>(),
&
x_i
,
&
out_i
);
}
SequenceSoftmaxFunctor
<
DeviceContext
,
T
>
seq_softmax_functor
;
seq_softmax_functor
(
ctx
.
template
device_context
<
DeviceContext
>(),
*
x
,
lod
[
level
],
out
);
}
};
template
<
typename
DeviceContext
,
typename
T
>
class
SequenceSoftmaxGradKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
*
out
=
ctx
.
Input
<
LoDTensor
>
(
"Out"
);
auto
*
out_grad
=
ctx
.
Input
<
LoDTensor
>
(
framework
::
GradVarName
(
"Out"
));
auto
*
x
=
ctx
.
Input
<
LoDTensor
>
(
"X"
);
auto
*
x_grad
=
ctx
.
Output
<
LoDTensor
>
(
framework
::
GradVarName
(
"X"
));
if
(
x_grad
)
{
x_grad
->
set_lod
(
x
->
lod
())
;
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
*
out
=
ctx
.
Input
<
LoDTensor
>
(
"Out"
);
auto
*
out_grad
=
ctx
.
Input
<
LoDTensor
>
(
framework
::
GradVarName
(
"Out"
));
auto
*
x
=
ctx
.
Input
<
LoDTensor
>
(
"X"
);
auto
*
x_grad
=
ctx
.
Output
<
LoDTensor
>
(
framework
::
GradVarName
(
"X"
));
if
(
!
x_grad
)
{
return
;
}
x_grad
->
set_lod
(
x
->
lod
());
auto
lod
=
x
->
lod
();
const
size_t
level
=
lod
.
size
()
-
1
;
x_grad
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
for
(
int
i
=
0
;
i
<
static_cast
<
int
>
(
lod
[
level
].
size
())
-
1
;
++
i
)
{
int
start_pos
=
static_cast
<
int
>
(
lod
[
level
][
i
]);
int
end_pos
=
static_cast
<
int
>
(
lod
[
level
][
i
+
1
]);
Tensor
out_i
=
out
->
Slice
(
start_pos
,
end_pos
);
Tensor
out_grad_i
=
out_grad
->
Slice
(
start_pos
,
end_pos
);
Tensor
x_grad_i
=
x_grad
->
Slice
(
start_pos
,
end_pos
);
// Reshape from (end_pos - start_pos) x 1UL to 1UL x (end_pos - start_pos)
framework
::
DDim
dims_i
=
framework
::
make_ddim
({
1UL
,
end_pos
-
start_pos
});
out_i
.
Resize
(
dims_i
);
out_grad_i
.
Resize
(
dims_i
);
x_grad_i
.
Resize
(
dims_i
);
math
::
SoftmaxGradFunctor
<
DeviceContext
,
T
>
()(
ctx
.
template
device_context
<
DeviceContext
>(),
&
out_i
,
&
out_grad_i
,
&
x_grad_i
);
}
SequenceSoftmaxGradFunctor
<
DeviceContext
,
T
>
seq_softmax_grad_functor
;
seq_softmax_grad_functor
(
ctx
.
template
device_context
<
DeviceContext
>(),
*
out_grad
,
*
out
,
lod
[
level
],
x_grad
);
}
};
...
...
python/paddle/fluid/layers/nn.py
浏览文件 @
6757a315
...
...
@@ -1275,7 +1275,7 @@ def sequence_conv(input,
return
helper
.
append_activation
(
pre_act
)
def
sequence_softmax
(
input
,
param_attr
=
None
,
bias_attr
=
None
,
use_cudnn
=
Tru
e
):
def
sequence_softmax
(
input
,
param_attr
=
None
,
bias_attr
=
None
,
use_cudnn
=
Fals
e
):
"""
This function computes the softmax activation among all time-steps for each
sequence. The dimension of each time-step should be 1. Thus, the shape of
...
...
@@ -1298,7 +1298,7 @@ def sequence_softmax(input, param_attr=None, bias_attr=None, use_cudnn=True):
bias_attr (ParamAttr|None): attributes for bias
param_attr (ParamAttr|None): attributes for parameter
use_cudnn (bool): Use cudnn kernel or not, it is valid only when the cudnn
\
library is installed. Default:
Tru
e
library is installed. Default:
Fals
e
Returns:
Variable: output of sequence_softmax
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录