提交 49a59421 编写于 作者: W wanghaox

fix some typos

......@@ -90,11 +90,11 @@ class SequenceSliceOpMaker : public framework::OpProtoAndCheckerMaker {
"a vector<int> to describe the length of every input sequence for "
"sub sequence item.");
AddOutput("Out",
"(LoDTensor), The output of SequenceSliceOp.");
"(LoDTensor), the output of SequenceSliceOp.");
AddComment(R"DOC(
Sequence slice operator
The operator crop a subsequence from given sequence with given start offset and subsequence length.
The operator crops a subsequence from given sequence with given start offset and subsequence length.
It only supports sequence (LoD Tensor with level number is 1).
- Case:
X = [[a1, a2;
......@@ -109,7 +109,7 @@ It only supports sequence (LoD Tensor with level number is 1).
b1, b2]
[e1, e2]]
LoD(Out) = {{0, 2, 3}}; Dims(Out) = (3, 2)
NOTE: The length of the input, offset and length should be the same. The offset start from 0.
NOTE: The first dimension size of input, the size of offset and Length, should be equal. The offset start from 0.
)DOC");
}
};
......
......@@ -83,7 +83,8 @@ class SequenceSliceOpKernel : public framework::OpKernel<T> {
PADDLE_ENFORCE_LT(
lod[0][i] + offset_data[i] + length_data[i],
lod[0][i + 1],
"The target tensor's length overflow")}
"The target tensor's length overflow")
}
out->mutable_data<T>(ctx.GetPlace());
auto out_lod = SequenceSliceLoD(*in, offset_data, length_data);
......@@ -140,27 +141,29 @@ class SequenceSliceGradOpKernel : public framework::OpKernel<T> {
auto lod = in->lod();
auto out_lod = out_grad->lod();
x_grad->mutable_data<T>(ctx.GetPlace());
math::SetConstant<Place, T> set_zero;
set_zero(ctx.device_context(), x_grad, static_cast<T>(0));
if (x_grad) {
x_grad->mutable_data<T>(ctx.GetPlace());
math::SetConstant<Place, T> set_zero;
set_zero(ctx.device_context(), x_grad, static_cast<T>(0));
auto out_grad_stride = framework::stride(out_grad->dims());
auto out_grad_stride = framework::stride(out_grad->dims());
for (size_t i = 0; i < out_lod[0].size() - 1; ++i) {
Tensor out_grad_t =
out_grad->Slice(static_cast<int>(out_lod[0][i]),
static_cast<int>(out_lod[0][i + 1]));
auto out_grad_stride = framework::stride(out_grad_t.dims());
for (size_t i = 0; i < out_lod[0].size() - 1; ++i) {
Tensor out_grad_t =
out_grad->Slice(static_cast<int>(out_lod[0][i]),
static_cast<int>(out_lod[0][i + 1]));
auto out_grad_stride = framework::stride(out_grad_t.dims());
auto x_grad_stride = framework::stride(x_grad->dims());
auto x_grad_stride = framework::stride(x_grad->dims());
Tensor x_grad_t = x_grad->Slice(
static_cast<int>(lod[0][i] + offset_data[i]),
static_cast<int>(lod[0][i] + offset_data[i] + length_data[i]));
Tensor x_grad_t = x_grad->Slice(
static_cast<int>(lod[0][i] + offset_data[i]),
static_cast<int>(lod[0][i] + offset_data[i] + length_data[i]));
StridedMemcpy<T>(ctx.device_context(), out_grad_t.data<T>(),
out_grad_stride, out_grad_t.dims(), x_grad_stride,
x_grad_t.data<T>());
StridedMemcpy<T>(ctx.device_context(), out_grad_t.data<T>(),
out_grad_stride, out_grad_t.dims(), x_grad_stride,
x_grad_t.data<T>());
}
}
}
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册