提交 34aac18c 编写于 作者: Q qingqing01 提交者: GitHub

Merge pull request #3 from reyoung/pr/4929

Several Enhancement
...@@ -68,7 +68,7 @@ class LSTMOp : public framework::OperatorWithKernel { ...@@ -68,7 +68,7 @@ class LSTMOp : public framework::OperatorWithKernel {
} else { } else {
PADDLE_ENFORCE_EQ(b_dims[1], 4 * frame_size, PADDLE_ENFORCE_EQ(b_dims[1], 4 * frame_size,
"The second dimension of Input(Bias) should be " "The second dimension of Input(Bias) should be "
"4 * %d if diable peepholes connection", "4 * %d if disable peepholes connection",
frame_size); frame_size);
} }
ctx->SetOutputDim("Hidden", {x_dims[0], frame_size}); ctx->SetOutputDim("Hidden", {x_dims[0], frame_size});
...@@ -86,7 +86,7 @@ class LSTMOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -86,7 +86,7 @@ class LSTMOpMaker : public framework::OpProtoAndCheckerMaker {
AddInput("Input", AddInput("Input",
"(LoDTensor) the first input is a LodTensor, which support " "(LoDTensor) the first input is a LodTensor, which support "
"variable-time length input sequence. The underlying tensor in " "variable-time length input sequence. The underlying tensor in "
"this LoDTenosr is a matrix with shape (T X 4D), where, T is the " "this LoDTensor is a matrix with shape (T X 4D), where, T is the "
"total time steps in this mini-batch, D is the hidden size."); "total time steps in this mini-batch, D is the hidden size.");
AddInput("H0", AddInput("H0",
"(Tensor, optional) the initial hidden state is an optional " "(Tensor, optional) the initial hidden state is an optional "
...@@ -112,7 +112,7 @@ class LSTMOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -112,7 +112,7 @@ class LSTMOpMaker : public framework::OpProtoAndCheckerMaker {
" - Bias = {b_i, b_f, b_c, b_o, W_ic, W_fc, W_oc}."); " - Bias = {b_i, b_f, b_c, b_o, W_ic, W_fc, W_oc}.");
AddOutput("BatchGate", AddOutput("BatchGate",
"(LoDTensor) This LoDTensor contains input gate, forget gate " "(LoDTensor) This LoDTensor contains input gate, forget gate "
"and output gate aftern the nonlinear computation. This " "and output gate after the nonlinear computation. This "
"LoDTensor has the same shape with the reorganized input, which " "LoDTensor has the same shape with the reorganized input, which "
"was also be called batch input. The LoD size is 2. The first " "was also be called batch input. The LoD size is 2. The first "
"LoD is the batch offsets and the second LoD contains the " "LoD is the batch offsets and the second LoD contains the "
...@@ -135,18 +135,18 @@ class LSTMOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -135,18 +135,18 @@ class LSTMOpMaker : public framework::OpProtoAndCheckerMaker {
.SetDefault(false); .SetDefault(false);
AddAttr<std::string>( AddAttr<std::string>(
"gateActivation", "gateActivation",
"(string, defalut: sigmoid)" "(string, default: sigmoid)"
"The activation for input gate, forget gate and output " "The activation for input gate, forget gate and output "
"gate, `sigmoid` by defalut.") "gate, `sigmoid` by default.")
.SetDefault("sigmoid"); .SetDefault("sigmoid");
AddAttr<std::string>("cellActivation", AddAttr<std::string>("cellActivation",
"(string, defalut: tanh)" "(string, default: tanh)"
"The activation for cell output, `tanh` by defalut.") "The activation for cell output, `tanh` by defalut.")
.SetDefault("tanh"); .SetDefault("tanh");
AddAttr<std::string>("candidateActivation", AddAttr<std::string>("candidateActivation",
"(string, defalut: tanh)" "(string, default: tanh)"
"The activation for candidate hidden state, " "The activation for candidate hidden state, "
"`tanh` by defalut.") "`tanh` by default.")
.SetDefault("tanh"); .SetDefault("tanh");
AddComment(R"DOC(Long-Short Term Memory (LSTM) Operator AddComment(R"DOC(Long-Short Term Memory (LSTM) Operator
......
...@@ -52,7 +52,7 @@ class LSTMKernel : public framework::OpKernel<T> { ...@@ -52,7 +52,7 @@ class LSTMKernel : public framework::OpKernel<T> {
to_batch(ctx.device_context(), *input, *batch_gate, is_reverse); to_batch(ctx.device_context(), *input, *batch_gate, is_reverse);
auto in_dims = input->dims(); auto in_dims = input->dims();
int frame_size = in_dims[1] / 4; int frame_size = static_cast<int>(in_dims[1] / 4);
framework::DDim dims({in_dims[0], frame_size}); framework::DDim dims({in_dims[0], frame_size});
if (bias) { if (bias) {
...@@ -70,7 +70,7 @@ class LSTMKernel : public framework::OpKernel<T> { ...@@ -70,7 +70,7 @@ class LSTMKernel : public framework::OpKernel<T> {
math::LstmMetaValue<T> lstm_value; math::LstmMetaValue<T> lstm_value;
T* bias_data = const_cast<T*>(bias->data<T>()); T* bias_data = const_cast<T*>(bias->data<T>());
// the code styple in LstmMetaValue will be updated later. // the code style in LstmMetaValue will be updated later.
lstm_value.checkIg = bias_data + 4 * frame_size; lstm_value.checkIg = bias_data + 4 * frame_size;
lstm_value.checkFg = lstm_value.checkIg + frame_size; lstm_value.checkFg = lstm_value.checkIg + frame_size;
lstm_value.checkOg = lstm_value.checkFg + frame_size; lstm_value.checkOg = lstm_value.checkFg + frame_size;
...@@ -83,15 +83,15 @@ class LSTMKernel : public framework::OpKernel<T> { ...@@ -83,15 +83,15 @@ class LSTMKernel : public framework::OpKernel<T> {
framework::LoDTensor batch_cell_pre_act; framework::LoDTensor batch_cell_pre_act;
batch_cell_pre_act.mutable_data<T>(dims, ctx.GetPlace()); batch_cell_pre_act.mutable_data<T>(dims, ctx.GetPlace());
auto batch_lod = batch_gate->lod()[0]; auto& batch_starts = batch_gate->lod()[0];
int num_batch = batch_lod.size() - 1; size_t num_batch = batch_starts.size() - 1;
auto gate_act = ctx.Attr<std::string>("gateActivation"); auto gate_act = ctx.Attr<std::string>("gateActivation");
auto cell_act = ctx.Attr<std::string>("cellActivation"); auto cell_act = ctx.Attr<std::string>("cellActivation");
auto cand_act = ctx.Attr<std::string>("candidateActivation"); auto cand_act = ctx.Attr<std::string>("candidateActivation");
for (int n = 0; n < num_batch; n++) { for (size_t n = 0; n < num_batch; n++) {
int bstart = batch_lod[n]; int bstart = static_cast<int>(batch_starts[n]);
int bend = batch_lod[n + 1]; int bend = static_cast<int>(batch_starts[n + 1]);
Tensor gate_t = batch_gate->Slice<T>(bstart, bend); Tensor gate_t = batch_gate->Slice<T>(bstart, bend);
Tensor out_t = batch_out.Slice<T>(bstart, bend); Tensor out_t = batch_out.Slice<T>(bstart, bend);
...@@ -101,14 +101,14 @@ class LSTMKernel : public framework::OpKernel<T> { ...@@ -101,14 +101,14 @@ class LSTMKernel : public framework::OpKernel<T> {
int cur_batch_size = bend - bstart; int cur_batch_size = bend - bstart;
if (n != 0) { if (n != 0) {
int pre_h_start = batch_lod[n - 1]; int pre_h_start = static_cast<int>(batch_starts[n - 1]);
int pre_h_end = pre_h_start + cur_batch_size; int pre_h_end = pre_h_start + cur_batch_size;
auto pre_hidden_t = batch_out.Slice<T>(pre_h_start, pre_h_end); auto pre_hidden_t = batch_out.Slice<T>(pre_h_start, pre_h_end);
math::matmul<Place, T>(ctx.device_context(), pre_hidden_t, false, math::matmul<Place, T>(ctx.device_context(), pre_hidden_t, false,
*weight, false, static_cast<T>(1.0), &gate_t, *weight, false, static_cast<T>(1.0), &gate_t,
static_cast<T>(1.0)); static_cast<T>(1.0));
} }
// else if : support the initial hidden and cell // else if : FIXME support the initial hidden and cell
lstm_value.gateValue = gate_t.data<T>(); lstm_value.gateValue = gate_t.data<T>();
lstm_value.outputValue = out_t.data<T>(); lstm_value.outputValue = out_t.data<T>();
......
...@@ -13,12 +13,9 @@ See the License for the specific language governing permissions and ...@@ -13,12 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/operators/math/detail/hl_activation_functions.h" #include "paddle/operators/math/detail/hl_activation_functions.h"
#include "paddle/platform/hostdevice.h"
#ifdef __CUDA_ARCH__ #include <type_traits>
#define INLINE __device__ inline
#else
#define INLINE inline
#endif
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -30,7 +27,7 @@ namespace forward { ...@@ -30,7 +27,7 @@ namespace forward {
template <class T> template <class T>
class lstm { class lstm {
public: public:
INLINE void operator()(T &valueIn, T &valueIg, T &valueFg, T &valueOg, HOSTDEVICE void operator()(T &valueIn, T &valueIg, T &valueFg, T &valueOg,
T &prevState, T &state, T &stateAtv, T &output, T &prevState, T &state, T &stateAtv, T &output,
T &checkI, T &checkF, T &checkO, T &checkI, T &checkF, T &checkO,
typename hppl::ForwardActType<T>::type actInput, typename hppl::ForwardActType<T>::type actInput,
...@@ -45,11 +42,13 @@ class lstm { ...@@ -45,11 +42,13 @@ class lstm {
output = valueOg * stateAtv; output = valueOg * stateAtv;
} }
#ifndef __NVCC__ #ifndef __NVCC__
#ifndef __AVX__ #ifndef __AVX__ // If not compiled with AVX instructs. Disable AVX by default
static const bool avx = false; static const bool avx = false;
#else #else
static const bool avx = true; // Only float support AVX optimization
INLINE void operator()(__m256 &valueIn, __m256 &valueIg, __m256 &valueFg, static const bool avx = std::is_same<T, float>::value;
HOSTDEVICE void operator()(__m256 &valueIn, __m256 &valueIg, __m256 &valueFg,
__m256 &valueOg, __m256 &prevState, __m256 &state, __m256 &valueOg, __m256 &prevState, __m256 &state,
__m256 &stateAtv, __m256 &output, __m256 &checkI, __m256 &stateAtv, __m256 &output, __m256 &checkI,
__m256 &checkF, __m256 &checkO, __m256 &checkF, __m256 &checkO,
...@@ -76,11 +75,12 @@ namespace backward { ...@@ -76,11 +75,12 @@ namespace backward {
template <class T> template <class T>
class lstm { class lstm {
public: public:
INLINE void operator()(T &valueIn, T &valueIg, T &valueFg, T &valueOg, HOSTDEVICE void operator()(T &valueIn, T &valueIg, T &valueFg, T &valueOg,
T &gradIn, T &gradIg, T &gradFg, T &gradOg, T &gradIn, T &gradIg, T &gradFg, T &gradOg,
T &prevState, T &prevStateGrad, T &state, T &stateGrad, T &prevState, T &prevStateGrad, T &state,
T &stateAtv, T &outputGrad, T &checkI, T &checkF, T &stateGrad, T &stateAtv, T &outputGrad,
T &checkO, T &checkIGrad, T &checkFGrad, T &checkOGrad, T &checkI, T &checkF, T &checkO, T &checkIGrad,
T &checkFGrad, T &checkOGrad,
typename hppl::BackwardActType<T>::type actInput, typename hppl::BackwardActType<T>::type actInput,
typename hppl::BackwardActType<T>::type actGate, typename hppl::BackwardActType<T>::type actGate,
typename hppl::BackwardActType<T>::type actState) { typename hppl::BackwardActType<T>::type actState) {
...@@ -95,18 +95,19 @@ class lstm { ...@@ -95,18 +95,19 @@ class lstm {
checkOGrad = gradOg * state; checkOGrad = gradOg * state;
} }
#ifndef __NVCC__ #ifndef __NVCC__
#ifndef __AVX__ #ifndef __AVX__ // If not compiled with AVX instructs. Disable AVX by default
static const bool avx = false; static const bool avx = false;
#else #else
static const bool avx = true; // Only float support AVX optimization
INLINE void operator()(__m256 &valueIn, __m256 &valueIg, __m256 &valueFg, static const bool avx = std::is_same<T, float>::value;
HOSTDEVICE void operator()(__m256 &valueIn, __m256 &valueIg, __m256 &valueFg,
__m256 &valueOg, __m256 &gradIn, __m256 &gradIg, __m256 &valueOg, __m256 &gradIn, __m256 &gradIg,
__m256 &gradFg, __m256 &gradOg, __m256 &prevState, __m256 &gradFg, __m256 &gradOg, __m256 &prevState,
__m256 &prevStateGrad, __m256 &state, __m256 &prevStateGrad, __m256 &state,
__m256 &stateGrad, __m256 &stateAtv, __m256 &stateGrad, __m256 &stateAtv,
__m256 &outputGrad, __m256 &checkI, __m256 &checkF, __m256 &outputGrad, __m256 &checkI, __m256 &checkF,
__m256 &checkO, __m256 &checkIGrad, __m256 &checkFGrad, __m256 &checkO, __m256 &checkIGrad,
__m256 &checkOGrad, __m256 &checkFGrad, __m256 &checkOGrad,
hppl::Active<__m256>::backward actInput, hppl::Active<__m256>::backward actInput,
hppl::Active<__m256>::backward actGate, hppl::Active<__m256>::backward actGate,
hppl::Active<__m256>::backward actState) { hppl::Active<__m256>::backward actState) {
......
...@@ -24,8 +24,8 @@ template <class T> ...@@ -24,8 +24,8 @@ template <class T>
struct LstmUnitFunctor<platform::CPUPlace, T> { struct LstmUnitFunctor<platform::CPUPlace, T> {
static void compute(const platform::DeviceContext& context, static void compute(const platform::DeviceContext& context,
LstmMetaValue<T> value, int frame_size, int batch_size, LstmMetaValue<T> value, int frame_size, int batch_size,
std::string gate_act, std::string cell_act, const std::string& gate_act, const std::string& cell_act,
std::string cand_act) { const std::string& cand_act) {
for (int b = 0; b < batch_size; b++) { for (int b = 0; b < batch_size; b++) {
detail::cpu_lstm_forward(detail::forward::lstm<T>(), value, frame_size, detail::cpu_lstm_forward(detail::forward::lstm<T>(), value, frame_size,
ActiveType(cand_act), ActiveType(gate_act), ActiveType(cand_act), ActiveType(gate_act),
...@@ -45,8 +45,9 @@ template <class T> ...@@ -45,8 +45,9 @@ template <class T>
struct LstmUnitGradFunctor<platform::CPUPlace, T> { struct LstmUnitGradFunctor<platform::CPUPlace, T> {
static void compute(const platform::DeviceContext& context, static void compute(const platform::DeviceContext& context,
LstmMetaValue<T> value, LstmMetaGrad<T> grad, LstmMetaValue<T> value, LstmMetaGrad<T> grad,
int frame_size, int batch_size, std::string gate_act, int frame_size, int batch_size,
std::string cell_act, std::string cand_act) { const std::string& gate_act, const std::string& cell_act,
const std::string& cand_act) {
for (int b = 0; b < batch_size; b++) { for (int b = 0; b < batch_size; b++) {
detail::cpu_lstm_backward(detail::backward::lstm<T>(), value, grad, detail::cpu_lstm_backward(detail::backward::lstm<T>(), value, grad,
frame_size, ActiveType(cand_act), frame_size, ActiveType(cand_act),
......
...@@ -24,8 +24,8 @@ template <class T> ...@@ -24,8 +24,8 @@ template <class T>
struct LstmUnitFunctor<platform::GPUPlace, T> { struct LstmUnitFunctor<platform::GPUPlace, T> {
static void compute(const platform::DeviceContext& context, static void compute(const platform::DeviceContext& context,
LstmMetaValue<T> value, int frame_size, int batch_size, LstmMetaValue<T> value, int frame_size, int batch_size,
std::string gate_act, std::string cell_act, const std::string& gate_act, const std::string& cell_act,
std::string cand_act) { const std::string& cand_act) {
detail::gpu_lstm_forward<T>(context, detail::forward::lstm<T>(), value, detail::gpu_lstm_forward<T>(context, detail::forward::lstm<T>(), value,
frame_size, batch_size, ActiveType(cand_act), frame_size, batch_size, ActiveType(cand_act),
ActiveType(gate_act), ActiveType(cell_act)); ActiveType(gate_act), ActiveType(cell_act));
...@@ -36,8 +36,9 @@ template <class T> ...@@ -36,8 +36,9 @@ template <class T>
struct LstmUnitGradFunctor<platform::GPUPlace, T> { struct LstmUnitGradFunctor<platform::GPUPlace, T> {
static void compute(const platform::DeviceContext& context, static void compute(const platform::DeviceContext& context,
LstmMetaValue<T> value, LstmMetaGrad<T> grad, LstmMetaValue<T> value, LstmMetaGrad<T> grad,
int frame_size, int batch_size, std::string gate_act, int frame_size, int batch_size,
std::string cell_act, std::string cand_act) { const std::string& gate_act, const std::string& cell_act,
const std::string& cand_act) {
detail::gpu_lstm_backward(context, detail::backward::lstm<T>(), value, grad, detail::gpu_lstm_backward(context, detail::backward::lstm<T>(), value, grad,
frame_size, batch_size, ActiveType(cand_act), frame_size, batch_size, ActiveType(cand_act),
ActiveType(gate_act), ActiveType(cell_act)); ActiveType(gate_act), ActiveType(cell_act));
......
...@@ -72,8 +72,8 @@ class LstmUnitFunctor { ...@@ -72,8 +72,8 @@ class LstmUnitFunctor {
public: public:
static void compute(const platform::DeviceContext &context, static void compute(const platform::DeviceContext &context,
LstmMetaValue<T> value, int frame_size, int batch_size, LstmMetaValue<T> value, int frame_size, int batch_size,
std::string gate_act, std::string cell_act, const std::string &gate_act, const std::string &cell_act,
std::string cand_act); const std::string &cand_act);
}; };
template <typename Place, typename T> template <typename Place, typename T>
...@@ -81,8 +81,9 @@ class LstmUnitGradFunctor { ...@@ -81,8 +81,9 @@ class LstmUnitGradFunctor {
public: public:
static void compute(const platform::DeviceContext &context, static void compute(const platform::DeviceContext &context,
LstmMetaValue<T> value, LstmMetaGrad<T> grad, LstmMetaValue<T> value, LstmMetaGrad<T> grad,
int frame_size, int batch_size, std::string gate_act, int frame_size, int batch_size,
std::string cell_act, std::string cand_act); const std::string &gate_act, const std::string &cell_act,
const std::string &cand_act);
}; };
} // namespace math } // namespace math
......
...@@ -51,8 +51,6 @@ class CopyMatrixRowsFunctor<platform::CPUPlace, T> { ...@@ -51,8 +51,6 @@ class CopyMatrixRowsFunctor<platform::CPUPlace, T> {
template class CopyMatrixRowsFunctor<platform::CPUPlace, float>; template class CopyMatrixRowsFunctor<platform::CPUPlace, float>;
template class CopyMatrixRowsFunctor<platform::CPUPlace, double>; template class CopyMatrixRowsFunctor<platform::CPUPlace, double>;
template class LoDTensor2BatchFunctor<platform::CPUPlace, float>;
template class LoDTensor2BatchFunctor<platform::CPUPlace, double>;
template class Batch2LoDTensorFunctor<platform::CPUPlace, float>; template class Batch2LoDTensorFunctor<platform::CPUPlace, float>;
template class Batch2LoDTensorFunctor<platform::CPUPlace, double>; template class Batch2LoDTensorFunctor<platform::CPUPlace, double>;
......
...@@ -21,7 +21,7 @@ namespace math { ...@@ -21,7 +21,7 @@ namespace math {
template <typename T, int BlockDimX, int BlockDimY, int GridDimX> template <typename T, int BlockDimX, int BlockDimY, int GridDimX>
__global__ void CopyMatrixRowsKernel(const T* src, T* dst, const size_t* index, __global__ void CopyMatrixRowsKernel(const T* src, T* dst, const size_t* index,
int64_t height, int64_t width, int64_t height, int64_t width,
const bool is_src_index) { bool is_src_index) {
int idx = threadIdx.x; int idx = threadIdx.x;
int idy = threadIdx.y; int idy = threadIdx.y;
int id = blockIdx.x + idy * GridDimX; int id = blockIdx.x + idy * GridDimX;
......
...@@ -31,19 +31,11 @@ class CopyMatrixRowsFunctor { ...@@ -31,19 +31,11 @@ class CopyMatrixRowsFunctor {
// The indexed rows are based on the input index. // The indexed rows are based on the input index.
void operator()(const platform::DeviceContext& context, void operator()(const platform::DeviceContext& context,
const framework::LoDTensor& src, const size_t* index, const framework::LoDTensor& src, const size_t* index,
framework::LoDTensor& dst, const bool is_src_index); framework::LoDTensor& dst, bool is_src_index);
}; };
template <typename Place, typename T> template <typename Place, typename T>
class LoDTensor2BatchFunctor { class LoDTensor2BatchFunctor {
public:
void operator()(const platform::DeviceContext& context,
const framework::LoDTensor& lod_tensor,
framework::LoDTensor& batch, const bool is_reverse) const {
auto lods = lod_tensor.lod();
PADDLE_ENFORCE_EQ(lods.size(), 1UL, "Only support one level sequence now.");
auto lod = lods[0];
// Calculate the length of each sequence and // Calculate the length of each sequence and
// sort sequence index by the length. // sort sequence index by the length.
// example: sequences = {s0, s1, s2} // example: sequences = {s0, s1, s2}
...@@ -58,6 +50,14 @@ class LoDTensor2BatchFunctor { ...@@ -58,6 +50,14 @@ class LoDTensor2BatchFunctor {
int seq_idx; int seq_idx;
}; };
public:
void operator()(const platform::DeviceContext& context,
const framework::LoDTensor& lod_tensor,
framework::LoDTensor& batch, bool is_reverse) const {
auto lods = lod_tensor.lod();
PADDLE_ENFORCE_EQ(lods.size(), 1UL, "Only support one level sequence now.");
auto lod = lods[0];
std::vector<SeqInfo> seq_info; std::vector<SeqInfo> seq_info;
for (size_t seq_id = 0; seq_id < lod.size() - 1; ++seq_id) { for (size_t seq_id = 0; seq_id < lod.size() - 1; ++seq_id) {
int length = lod[seq_id + 1] - lod[seq_id]; int length = lod[seq_id + 1] - lod[seq_id];
...@@ -75,31 +75,34 @@ class LoDTensor2BatchFunctor { ...@@ -75,31 +75,34 @@ class LoDTensor2BatchFunctor {
// batchIndex = {b0, b1, b2, b3, b4} // batchIndex = {b0, b1, b2, b3, b4}
// b0: 1 0 2, b1: 1 0 2, b2: 1 0 2, b3: 1 0, b4: 1 // b0: 1 0 2, b1: 1 0 2, b2: 1 0 2, b3: 1 0, b4: 1
// batch_start_positions[6] = {0, 3, 6, 9, 11, 12} // batch_start_positions[6] = {0, 3, 6, 9, 11, 12}
// batch_start_positions[0] = len(b0)
// batch_start_positions[1] = len(b0) + len(b1)
// batch_start_positions[2] = len(b0) + len(b1) + len(b2)
// ...
// seq2batch_idx[12] = {4, 0, 9, // seq2batch_idx[12] = {4, 0, 9,
// 5, 1, 10, // 5, 1, 10,
// 6, 2, 11, // 6, 2, 11,
// 7, 3, // 7, 3,
// 8} // 8}
// The batch number represents batch size after rearranging the // The batch number represents batch size after rearranging the
// input LodTensor. It is also the maximum length of input sequence. // input LodTensor. It is also the maximum length of input sequence.
paddle::framework::LoD batch_lods; paddle::framework::LoD batch_lods;
batch_lods.push_back(std::vector<size_t>{0}); batch_lods.emplace_back(std::vector<size_t>{0});
batch_lods.push_back(std::vector<size_t>{0}); batch_lods.emplace_back(std::vector<size_t>{0});
// batch_lods[0] is the start positions for batch LoDTensor // batch_lods[0] is the start positions for batch LoDTensor
int num_batch = (size_t)seq_info[0].length; int num_batch = seq_info[0].length;
batch_lods[0].resize(num_batch + 1); batch_lods[0].resize(static_cast<size_t>(num_batch + 1));
// batch_lods[1] is the raw index in the input LoDTensor // batch_lods[1] is the raw index in the input LoDTensor
auto dims = lod_tensor.dims(); auto dims = lod_tensor.dims();
batch_lods[1].resize(dims[0]); batch_lods[1].resize(static_cast<size_t>(dims[0]));
size_t* batch_starts = batch_lods[0].data(); size_t* batch_starts = batch_lods[0].data();
size_t* seq2batch_idx = batch_lods[1].data(); size_t* seq2batch_idx = batch_lods[1].data();
batch_starts[0] = 0; batch_starts[0] = 0;
for (size_t n = 0; n < num_batch; n++) { for (size_t n = 0; n < num_batch; n++) {
int batch_id = batch_starts[n]; auto batch_id = static_cast<int>(batch_starts[n]);
for (size_t i = 0; i < seq_info.size(); ++i) { for (size_t i = 0; i < seq_info.size(); ++i) {
size_t seq_len = seq_info[i].length; size_t seq_len = seq_info[i].length;
int start = seq_info[i].start; int start = seq_info[i].start;
...@@ -114,7 +117,7 @@ class LoDTensor2BatchFunctor { ...@@ -114,7 +117,7 @@ class LoDTensor2BatchFunctor {
break; break;
} }
} }
batch_starts[n + 1] = batch_id; batch_starts[n + 1] = static_cast<size_t>(batch_id);
} }
batch.set_lod(batch_lods); batch.set_lod(batch_lods);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册