未验证 提交 a9c3e32d 编写于 作者: Z zyfncg 提交者: GitHub

fix bug of test_pad_op for cinn (#53772)

上级 3d6bd6a4
......@@ -1949,15 +1949,15 @@ void pad_grad(const Tensor& input,
size_t rank = input.dims().size();
auto out_dims = out_grad.dims();
std::vector<int> starts(rank, 0);
std::vector<int64_t> starts(rank, 0);
std::vector<int64_t> ends(rank, 0);
std::vector<int64_t> axes(rank, 0);
std::vector<int64_t> infer_flags(rank, 1);
std::vector<int64_t> decrease_axis({});
for (size_t i = 0; i < rank; ++i) {
starts.push_back(static_cast<int>(paddings[2 * i]));
ends.push_back(static_cast<int64_t>(out_dims[i] - paddings[2 * i + 1]));
axes.push_back(i);
starts[i] = static_cast<int64_t>(paddings[2 * i]);
ends[i] = static_cast<int64_t>(out_dims[i] - paddings[2 * i + 1]);
axes[i] = i;
}
auto out_tmp =
slice<T>(out_grad, axes, starts, ends, infer_flags, decrease_axis);
......
......@@ -51,7 +51,6 @@ class TestPadOp(OpTest):
}
self.prim_op_type = "prim"
self.public_python_api = pad_wrapper
self.enable_cinn = False
def get_dtype(self):
return np.float64
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册