提交 9725db0d 编写于 作者: M minqiyang

Fix copy wrong pos bug

test=develop
上级 9c687090
......@@ -242,7 +242,7 @@ class SequencePoolGradFunctor<platform::CPUDeviceContext, T> {
auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context);
for (int i = 0; i < static_cast<int>(lod.size()) - 1; ++i) {
int64_t h = static_cast<int64_t>(lod[i + 1] - lod[i]);
int64_t in_offset = lod[i];
int64_t in_offset = lod[i] * in_w;
const T* out_pos = out_g_data + i * out_w;
T* in_pos = in_g_data + in_offset;
for (int r = 0; r != h; ++r) {
......
......@@ -70,7 +70,7 @@ void TestSequencePoolingSum(const paddle::framework::LoD& lod) {
EXPECT_EQ(in_grad.lod(), lod);
if (paddle::platform::is_cpu_place(*place)) {
for (int64_t i = 0; i < cpu_in_grad.lod()[0].size() - 1; ++i) {
for (int64_t i = 0; i < in_grad.lod()[0].size() - 1; ++i) {
int64_t begin = in_grad.lod()[0][i];
int64_t end = in_grad.lod()[0][i + 1];
paddle::framework::Tensor tmp = in_grad.Slice(begin, end);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册