提交 9430bc32 编写于 作者: Z zchen0211

fix all bugs

上级 03d0040c
......@@ -75,12 +75,12 @@ void ScatterUpdate(const platform::Place& place,
auto dst_dims = output->dims();
// check src shape and dst shape should match
for (size_t i = 1; i < src_dims.size(); i++)
for (int i = 1; i < src_dims.size(); i++)
PADDLE_ENFORCE(src_dims[i] == dst_dims[i]);
// slice size
size_t slice_size = 1;
for (size_t i = 0; i < src_dims.size(); ++i) slice_size *= src_dims[i];
for (int i = 0; i < src_dims.size(); ++i) slice_size *= src_dims[i];
if (platform::is_cpu_place(place)) {
CPUScatterUpdate<T>(src, index->data<int>(), index_size, output);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册