Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
65906ef1
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
65906ef1
编写于
10月 20, 2017
作者:
Y
Yu Yang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Several Enhancement
上级
694bc64a
变更
9
显示空白变更内容
内联
并排
Showing
9 changed file
with
102 addition
and
97 deletion
+102
-97
paddle/operators/lstm_op.cc
paddle/operators/lstm_op.cc
+8
-8
paddle/operators/lstm_op.h
paddle/operators/lstm_op.h
+9
-9
paddle/operators/math/detail/lstm_kernel.h
paddle/operators/math/detail/lstm_kernel.h
+42
-41
paddle/operators/math/lstm_compute.cc
paddle/operators/math/lstm_compute.cc
+5
-4
paddle/operators/math/lstm_compute.cu
paddle/operators/math/lstm_compute.cu
+5
-4
paddle/operators/math/lstm_compute.h
paddle/operators/math/lstm_compute.h
+5
-4
paddle/operators/math/sequence2batch.cc
paddle/operators/math/sequence2batch.cc
+0
-2
paddle/operators/math/sequence2batch.cu
paddle/operators/math/sequence2batch.cu
+1
-1
paddle/operators/math/sequence2batch.h
paddle/operators/math/sequence2batch.h
+27
-24
未找到文件。
paddle/operators/lstm_op.cc
浏览文件 @
65906ef1
...
...
@@ -68,7 +68,7 @@ class LSTMOp : public framework::OperatorWithKernel {
}
else
{
PADDLE_ENFORCE_EQ
(
b_dims
[
1
],
4
*
frame_size
,
"The second dimension of Input(Bias) should be "
"4 * %d if diable peepholes connection"
,
"4 * %d if di
s
able peepholes connection"
,
frame_size
);
}
ctx
->
SetOutputDim
(
"Hidden"
,
{
x_dims
[
0
],
frame_size
});
...
...
@@ -86,7 +86,7 @@ class LSTMOpMaker : public framework::OpProtoAndCheckerMaker {
AddInput
(
"Input"
,
"(LoDTensor) the first input is a LodTensor, which support "
"variable-time length input sequence. The underlying tensor in "
"this LoDTen
os
r is a matrix with shape (T X 4D), where, T is the "
"this LoDTen
so
r is a matrix with shape (T X 4D), where, T is the "
"total time steps in this mini-batch, D is the hidden size."
);
AddInput
(
"H0"
,
"(Tensor, optional) the initial hidden state is an optional "
...
...
@@ -112,7 +112,7 @@ class LSTMOpMaker : public framework::OpProtoAndCheckerMaker {
" - Bias = {b_i, b_f, b_c, b_o, W_ic, W_fc, W_oc}."
);
AddOutput
(
"BatchGate"
,
"(LoDTensor) This LoDTensor contains input gate, forget gate "
"and output gate after
n
the nonlinear computation. This "
"and output gate after the nonlinear computation. This "
"LoDTensor has the same shape with the reorganized input, which "
"was also be called batch input. The LoD size is 2. The first "
"LoD is the batch offsets and the second LoD contains the "
...
...
@@ -135,18 +135,18 @@ class LSTMOpMaker : public framework::OpProtoAndCheckerMaker {
.
SetDefault
(
false
);
AddAttr
<
std
::
string
>
(
"gateActivation"
,
"(string, defa
lu
t: sigmoid)"
"(string, defa
ul
t: sigmoid)"
"The activation for input gate, forget gate and output "
"gate, `sigmoid` by defa
lu
t."
)
"gate, `sigmoid` by defa
ul
t."
)
.
SetDefault
(
"sigmoid"
);
AddAttr
<
std
::
string
>
(
"cellActivation"
,
"(string, defa
lu
t: tanh)"
"(string, defa
ul
t: tanh)"
"The activation for cell output, `tanh` by defalut."
)
.
SetDefault
(
"tanh"
);
AddAttr
<
std
::
string
>
(
"candidateActivation"
,
"(string, defa
lu
t: tanh)"
"(string, defa
ul
t: tanh)"
"The activation for candidate hidden state, "
"`tanh` by defa
lu
t."
)
"`tanh` by defa
ul
t."
)
.
SetDefault
(
"tanh"
);
AddComment
(
R"DOC(Long-Short Term Memory (LSTM) Operator
...
...
paddle/operators/lstm_op.h
浏览文件 @
65906ef1
...
...
@@ -52,7 +52,7 @@ class LSTMKernel : public framework::OpKernel<T> {
to_batch
(
ctx
.
device_context
(),
*
input
,
*
batch_gate
,
is_reverse
);
auto
in_dims
=
input
->
dims
();
int
frame_size
=
in_dims
[
1
]
/
4
;
int
frame_size
=
static_cast
<
int
>
(
in_dims
[
1
]
/
4
)
;
framework
::
DDim
dims
({
in_dims
[
0
],
frame_size
});
if
(
bias
)
{
...
...
@@ -70,7 +70,7 @@ class LSTMKernel : public framework::OpKernel<T> {
math
::
LstmMetaValue
<
T
>
lstm_value
;
T
*
bias_data
=
const_cast
<
T
*>
(
bias
->
data
<
T
>
());
// the code sty
p
le in LstmMetaValue will be updated later.
// the code style in LstmMetaValue will be updated later.
lstm_value
.
checkIg
=
bias_data
+
4
*
frame_size
;
lstm_value
.
checkFg
=
lstm_value
.
checkIg
+
frame_size
;
lstm_value
.
checkOg
=
lstm_value
.
checkFg
+
frame_size
;
...
...
@@ -83,15 +83,15 @@ class LSTMKernel : public framework::OpKernel<T> {
framework
::
LoDTensor
batch_cell_pre_act
;
batch_cell_pre_act
.
mutable_data
<
T
>
(
dims
,
ctx
.
GetPlace
());
auto
batch_lod
=
batch_gate
->
lod
()[
0
];
int
num_batch
=
batch_lod
.
size
()
-
1
;
auto
&
batch_starts
=
batch_gate
->
lod
()[
0
];
size_t
num_batch
=
batch_starts
.
size
()
-
1
;
auto
gate_act
=
ctx
.
Attr
<
std
::
string
>
(
"gateActivation"
);
auto
cell_act
=
ctx
.
Attr
<
std
::
string
>
(
"cellActivation"
);
auto
cand_act
=
ctx
.
Attr
<
std
::
string
>
(
"candidateActivation"
);
for
(
in
t
n
=
0
;
n
<
num_batch
;
n
++
)
{
int
bstart
=
batch_lod
[
n
]
;
int
bend
=
batch_lod
[
n
+
1
]
;
for
(
size_
t
n
=
0
;
n
<
num_batch
;
n
++
)
{
int
bstart
=
static_cast
<
int
>
(
batch_starts
[
n
])
;
int
bend
=
static_cast
<
int
>
(
batch_starts
[
n
+
1
])
;
Tensor
gate_t
=
batch_gate
->
Slice
<
T
>
(
bstart
,
bend
);
Tensor
out_t
=
batch_out
.
Slice
<
T
>
(
bstart
,
bend
);
...
...
@@ -101,14 +101,14 @@ class LSTMKernel : public framework::OpKernel<T> {
int
cur_batch_size
=
bend
-
bstart
;
if
(
n
!=
0
)
{
int
pre_h_start
=
batch_lod
[
n
-
1
]
;
int
pre_h_start
=
static_cast
<
int
>
(
batch_starts
[
n
-
1
])
;
int
pre_h_end
=
pre_h_start
+
cur_batch_size
;
auto
pre_hidden_t
=
batch_out
.
Slice
<
T
>
(
pre_h_start
,
pre_h_end
);
math
::
matmul
<
Place
,
T
>
(
ctx
.
device_context
(),
pre_hidden_t
,
false
,
*
weight
,
false
,
static_cast
<
T
>
(
1.0
),
&
gate_t
,
static_cast
<
T
>
(
1.0
));
}
// else if : support the initial hidden and cell
// else if :
FIXME
support the initial hidden and cell
lstm_value
.
gateValue
=
gate_t
.
data
<
T
>
();
lstm_value
.
outputValue
=
out_t
.
data
<
T
>
();
...
...
paddle/operators/math/detail/lstm_kernel.h
浏览文件 @
65906ef1
...
...
@@ -13,12 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/math/detail/hl_activation_functions.h"
#include "paddle/platform/hostdevice.h"
#ifdef __CUDA_ARCH__
#define INLINE __device__ inline
#else
#define INLINE inline
#endif
#include <type_traits>
namespace
paddle
{
namespace
operators
{
...
...
@@ -30,7 +27,7 @@ namespace forward {
template
<
class
T
>
class
lstm
{
public:
INLIN
E
void
operator
()(
T
&
valueIn
,
T
&
valueIg
,
T
&
valueFg
,
T
&
valueOg
,
HOSTDEVIC
E
void
operator
()(
T
&
valueIn
,
T
&
valueIg
,
T
&
valueFg
,
T
&
valueOg
,
T
&
prevState
,
T
&
state
,
T
&
stateAtv
,
T
&
output
,
T
&
checkI
,
T
&
checkF
,
T
&
checkO
,
typename
hppl
::
ForwardActType
<
T
>::
type
actInput
,
...
...
@@ -45,11 +42,13 @@ class lstm {
output
=
valueOg
*
stateAtv
;
}
#ifndef __NVCC__
#ifndef __AVX__
#ifndef __AVX__
// If not compiled with AVX instructs. Disable AVX by default
static
const
bool
avx
=
false
;
#else
static
const
bool
avx
=
true
;
INLINE
void
operator
()(
__m256
&
valueIn
,
__m256
&
valueIg
,
__m256
&
valueFg
,
// Only float support AVX optimization
static
const
bool
avx
=
std
::
is_same
<
T
,
float
>::
value
;
HOSTDEVICE
void
operator
()(
__m256
&
valueIn
,
__m256
&
valueIg
,
__m256
&
valueFg
,
__m256
&
valueOg
,
__m256
&
prevState
,
__m256
&
state
,
__m256
&
stateAtv
,
__m256
&
output
,
__m256
&
checkI
,
__m256
&
checkF
,
__m256
&
checkO
,
...
...
@@ -76,11 +75,12 @@ namespace backward {
template
<
class
T
>
class
lstm
{
public:
INLIN
E
void
operator
()(
T
&
valueIn
,
T
&
valueIg
,
T
&
valueFg
,
T
&
valueOg
,
HOSTDEVIC
E
void
operator
()(
T
&
valueIn
,
T
&
valueIg
,
T
&
valueFg
,
T
&
valueOg
,
T
&
gradIn
,
T
&
gradIg
,
T
&
gradFg
,
T
&
gradOg
,
T
&
prevState
,
T
&
prevStateGrad
,
T
&
state
,
T
&
stateGrad
,
T
&
stateAtv
,
T
&
outputGrad
,
T
&
checkI
,
T
&
checkF
,
T
&
checkO
,
T
&
checkIGrad
,
T
&
checkFGrad
,
T
&
checkOGrad
,
T
&
prevState
,
T
&
prevStateGrad
,
T
&
state
,
T
&
stateGrad
,
T
&
stateAtv
,
T
&
outputGrad
,
T
&
checkI
,
T
&
checkF
,
T
&
checkO
,
T
&
checkIGrad
,
T
&
checkFGrad
,
T
&
checkOGrad
,
typename
hppl
::
BackwardActType
<
T
>::
type
actInput
,
typename
hppl
::
BackwardActType
<
T
>::
type
actGate
,
typename
hppl
::
BackwardActType
<
T
>::
type
actState
)
{
...
...
@@ -95,18 +95,19 @@ class lstm {
checkOGrad
=
gradOg
*
state
;
}
#ifndef __NVCC__
#ifndef __AVX__
#ifndef __AVX__
// If not compiled with AVX instructs. Disable AVX by default
static
const
bool
avx
=
false
;
#else
static
const
bool
avx
=
true
;
INLINE
void
operator
()(
__m256
&
valueIn
,
__m256
&
valueIg
,
__m256
&
valueFg
,
// Only float support AVX optimization
static
const
bool
avx
=
std
::
is_same
<
T
,
float
>::
value
;
HOSTDEVICE
void
operator
()(
__m256
&
valueIn
,
__m256
&
valueIg
,
__m256
&
valueFg
,
__m256
&
valueOg
,
__m256
&
gradIn
,
__m256
&
gradIg
,
__m256
&
gradFg
,
__m256
&
gradOg
,
__m256
&
prevState
,
__m256
&
prevStateGrad
,
__m256
&
state
,
__m256
&
stateGrad
,
__m256
&
stateAtv
,
__m256
&
outputGrad
,
__m256
&
checkI
,
__m256
&
checkF
,
__m256
&
checkO
,
__m256
&
checkIGrad
,
__m256
&
checkF
Grad
,
__m256
&
checkOGrad
,
__m256
&
checkO
,
__m256
&
checkI
Grad
,
__m256
&
checkFGrad
,
__m256
&
checkOGrad
,
hppl
::
Active
<
__m256
>::
backward
actInput
,
hppl
::
Active
<
__m256
>::
backward
actGate
,
hppl
::
Active
<
__m256
>::
backward
actState
)
{
...
...
paddle/operators/math/lstm_compute.cc
浏览文件 @
65906ef1
...
...
@@ -24,8 +24,8 @@ template <class T>
struct
LstmUnitFunctor
<
platform
::
CPUPlace
,
T
>
{
static
void
compute
(
const
platform
::
DeviceContext
&
context
,
LstmMetaValue
<
T
>
value
,
int
frame_size
,
int
batch_size
,
std
::
string
gate_act
,
std
::
string
cell_act
,
std
::
string
cand_act
)
{
const
std
::
string
&
gate_act
,
const
std
::
string
&
cell_act
,
const
std
::
string
&
cand_act
)
{
for
(
int
b
=
0
;
b
<
batch_size
;
b
++
)
{
detail
::
cpu_lstm_forward
(
detail
::
forward
::
lstm
<
T
>
(),
value
,
frame_size
,
ActiveType
(
cand_act
),
ActiveType
(
gate_act
),
...
...
@@ -45,8 +45,9 @@ template <class T>
struct
LstmUnitGradFunctor
<
platform
::
CPUPlace
,
T
>
{
static
void
compute
(
const
platform
::
DeviceContext
&
context
,
LstmMetaValue
<
T
>
value
,
LstmMetaGrad
<
T
>
grad
,
int
frame_size
,
int
batch_size
,
std
::
string
gate_act
,
std
::
string
cell_act
,
std
::
string
cand_act
)
{
int
frame_size
,
int
batch_size
,
const
std
::
string
&
gate_act
,
const
std
::
string
&
cell_act
,
const
std
::
string
&
cand_act
)
{
for
(
int
b
=
0
;
b
<
batch_size
;
b
++
)
{
detail
::
cpu_lstm_backward
(
detail
::
backward
::
lstm
<
T
>
(),
value
,
grad
,
frame_size
,
ActiveType
(
cand_act
),
...
...
paddle/operators/math/lstm_compute.cu
浏览文件 @
65906ef1
...
...
@@ -24,8 +24,8 @@ template <class T>
struct
LstmUnitFunctor
<
platform
::
GPUPlace
,
T
>
{
static
void
compute
(
const
platform
::
DeviceContext
&
context
,
LstmMetaValue
<
T
>
value
,
int
frame_size
,
int
batch_size
,
std
::
string
gate_act
,
std
::
string
cell_act
,
std
::
string
cand_act
)
{
const
std
::
string
&
gate_act
,
const
std
::
string
&
cell_act
,
const
std
::
string
&
cand_act
)
{
detail
::
gpu_lstm_forward
<
T
>
(
context
,
detail
::
forward
::
lstm
<
T
>
(),
value
,
frame_size
,
batch_size
,
ActiveType
(
cand_act
),
ActiveType
(
gate_act
),
ActiveType
(
cell_act
));
...
...
@@ -36,8 +36,9 @@ template <class T>
struct
LstmUnitGradFunctor
<
platform
::
GPUPlace
,
T
>
{
static
void
compute
(
const
platform
::
DeviceContext
&
context
,
LstmMetaValue
<
T
>
value
,
LstmMetaGrad
<
T
>
grad
,
int
frame_size
,
int
batch_size
,
std
::
string
gate_act
,
std
::
string
cell_act
,
std
::
string
cand_act
)
{
int
frame_size
,
int
batch_size
,
const
std
::
string
&
gate_act
,
const
std
::
string
&
cell_act
,
const
std
::
string
&
cand_act
)
{
detail
::
gpu_lstm_backward
(
context
,
detail
::
backward
::
lstm
<
T
>
(),
value
,
grad
,
frame_size
,
batch_size
,
ActiveType
(
cand_act
),
ActiveType
(
gate_act
),
ActiveType
(
cell_act
));
...
...
paddle/operators/math/lstm_compute.h
浏览文件 @
65906ef1
...
...
@@ -72,8 +72,8 @@ class LstmUnitFunctor {
public:
static
void
compute
(
const
platform
::
DeviceContext
&
context
,
LstmMetaValue
<
T
>
value
,
int
frame_size
,
int
batch_size
,
std
::
string
gate_act
,
std
::
string
cell_act
,
std
::
string
cand_act
);
const
std
::
string
&
gate_act
,
const
std
::
string
&
cell_act
,
const
std
::
string
&
cand_act
);
};
template
<
typename
Place
,
typename
T
>
...
...
@@ -81,8 +81,9 @@ class LstmUnitGradFunctor {
public:
static
void
compute
(
const
platform
::
DeviceContext
&
context
,
LstmMetaValue
<
T
>
value
,
LstmMetaGrad
<
T
>
grad
,
int
frame_size
,
int
batch_size
,
std
::
string
gate_act
,
std
::
string
cell_act
,
std
::
string
cand_act
);
int
frame_size
,
int
batch_size
,
const
std
::
string
&
gate_act
,
const
std
::
string
&
cell_act
,
const
std
::
string
&
cand_act
);
};
}
// namespace math
...
...
paddle/operators/math/sequence2batch.cc
浏览文件 @
65906ef1
...
...
@@ -51,8 +51,6 @@ class CopyMatrixRowsFunctor<platform::CPUPlace, T> {
template
class
CopyMatrixRowsFunctor
<
platform
::
CPUPlace
,
float
>;
template
class
CopyMatrixRowsFunctor
<
platform
::
CPUPlace
,
double
>;
template
class
LoDTensor2BatchFunctor
<
platform
::
CPUPlace
,
float
>;
template
class
LoDTensor2BatchFunctor
<
platform
::
CPUPlace
,
double
>;
template
class
Batch2LoDTensorFunctor
<
platform
::
CPUPlace
,
float
>;
template
class
Batch2LoDTensorFunctor
<
platform
::
CPUPlace
,
double
>;
...
...
paddle/operators/math/sequence2batch.cu
浏览文件 @
65906ef1
...
...
@@ -21,7 +21,7 @@ namespace math {
template
<
typename
T
,
int
BlockDimX
,
int
BlockDimY
,
int
GridDimX
>
__global__
void
CopyMatrixRowsKernel
(
const
T
*
src
,
T
*
dst
,
const
size_t
*
index
,
int64_t
height
,
int64_t
width
,
const
bool
is_src_index
)
{
bool
is_src_index
)
{
int
idx
=
threadIdx
.
x
;
int
idy
=
threadIdx
.
y
;
int
id
=
blockIdx
.
x
+
idy
*
GridDimX
;
...
...
paddle/operators/math/sequence2batch.h
浏览文件 @
65906ef1
...
...
@@ -31,19 +31,11 @@ class CopyMatrixRowsFunctor {
// The indexed rows are based on the input index.
void
operator
()(
const
platform
::
DeviceContext
&
context
,
const
framework
::
LoDTensor
&
src
,
const
size_t
*
index
,
framework
::
LoDTensor
&
dst
,
const
bool
is_src_index
);
framework
::
LoDTensor
&
dst
,
bool
is_src_index
);
};
template
<
typename
Place
,
typename
T
>
class
LoDTensor2BatchFunctor
{
public:
void
operator
()(
const
platform
::
DeviceContext
&
context
,
const
framework
::
LoDTensor
&
lod_tensor
,
framework
::
LoDTensor
&
batch
,
const
bool
is_reverse
)
const
{
auto
lods
=
lod_tensor
.
lod
();
PADDLE_ENFORCE_EQ
(
lods
.
size
(),
1UL
,
"Only support one level sequence now."
);
auto
lod
=
lods
[
0
];
// Calculate the length of each sequence and
// sort sequence index by the length.
// example: sequences = {s0, s1, s2}
...
...
@@ -58,6 +50,14 @@ class LoDTensor2BatchFunctor {
int
seq_idx
;
};
public:
void
operator
()(
const
platform
::
DeviceContext
&
context
,
const
framework
::
LoDTensor
&
lod_tensor
,
framework
::
LoDTensor
&
batch
,
bool
is_reverse
)
const
{
auto
lods
=
lod_tensor
.
lod
();
PADDLE_ENFORCE_EQ
(
lods
.
size
(),
1UL
,
"Only support one level sequence now."
);
auto
lod
=
lods
[
0
];
std
::
vector
<
SeqInfo
>
seq_info
;
for
(
size_t
seq_id
=
0
;
seq_id
<
lod
.
size
()
-
1
;
++
seq_id
)
{
int
length
=
lod
[
seq_id
+
1
]
-
lod
[
seq_id
];
...
...
@@ -75,31 +75,34 @@ class LoDTensor2BatchFunctor {
// batchIndex = {b0, b1, b2, b3, b4}
// b0: 1 0 2, b1: 1 0 2, b2: 1 0 2, b3: 1 0, b4: 1
// batch_start_positions[6] = {0, 3, 6, 9, 11, 12}
// batch_start_positions[0] = len(b0)
// batch_start_positions[1] = len(b0) + len(b1)
// batch_start_positions[2] = len(b0) + len(b1) + len(b2)
// ...
// seq2batch_idx[12] = {4, 0, 9,
// 5, 1, 10,
// 6, 2, 11,
// 7, 3,
// 8}
// The batch number represents batch size after rearranging the
// input LodTensor. It is also the maximum length of input sequence.
paddle
::
framework
::
LoD
batch_lods
;
batch_lods
.
push
_back
(
std
::
vector
<
size_t
>
{
0
});
batch_lods
.
push
_back
(
std
::
vector
<
size_t
>
{
0
});
batch_lods
.
emplace
_back
(
std
::
vector
<
size_t
>
{
0
});
batch_lods
.
emplace
_back
(
std
::
vector
<
size_t
>
{
0
});
// batch_lods[0] is the start positions for batch LoDTensor
int
num_batch
=
(
size_t
)
seq_info
[
0
].
length
;
batch_lods
[
0
].
resize
(
num_batch
+
1
);
int
num_batch
=
seq_info
[
0
].
length
;
batch_lods
[
0
].
resize
(
static_cast
<
size_t
>
(
num_batch
+
1
)
);
// batch_lods[1] is the raw index in the input LoDTensor
auto
dims
=
lod_tensor
.
dims
();
batch_lods
[
1
].
resize
(
dims
[
0
]
);
batch_lods
[
1
].
resize
(
static_cast
<
size_t
>
(
dims
[
0
])
);
size_t
*
batch_starts
=
batch_lods
[
0
].
data
();
size_t
*
seq2batch_idx
=
batch_lods
[
1
].
data
();
batch_starts
[
0
]
=
0
;
for
(
size_t
n
=
0
;
n
<
num_batch
;
n
++
)
{
int
batch_id
=
batch_starts
[
n
]
;
auto
batch_id
=
static_cast
<
int
>
(
batch_starts
[
n
])
;
for
(
size_t
i
=
0
;
i
<
seq_info
.
size
();
++
i
)
{
size_t
seq_len
=
seq_info
[
i
].
length
;
int
start
=
seq_info
[
i
].
start
;
...
...
@@ -114,7 +117,7 @@ class LoDTensor2BatchFunctor {
break
;
}
}
batch_starts
[
n
+
1
]
=
batch_id
;
batch_starts
[
n
+
1
]
=
static_cast
<
size_t
>
(
batch_id
)
;
}
batch
.
set_lod
(
batch_lods
);
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录