Unverified commit a9c3e32d authored by zyfncg, committed by GitHub

fix bug of test_pad_op for cinn (#53772)

Parent 3d6bd6a4
@@ -1949,15 +1949,15 @@ void pad_grad(const Tensor& input,
     size_t rank = input.dims().size();
     auto out_dims = out_grad.dims();

-    std::vector<int> starts(rank, 0);
+    std::vector<int64_t> starts(rank, 0);
     std::vector<int64_t> ends(rank, 0);
     std::vector<int64_t> axes(rank, 0);
     std::vector<int64_t> infer_flags(rank, 1);
     std::vector<int64_t> decrease_axis({});
     for (size_t i = 0; i < rank; ++i) {
-      starts.push_back(static_cast<int>(paddings[2 * i]));
-      ends.push_back(static_cast<int64_t>(out_dims[i] - paddings[2 * i + 1]));
-      axes.push_back(i);
+      starts[i] = static_cast<int64_t>(paddings[2 * i]);
+      ends[i] = static_cast<int64_t>(out_dims[i] - paddings[2 * i + 1]);
+      axes[i] = i;
     }
     auto out_tmp =
         slice<T>(out_grad, axes, starts, ends, infer_flags, decrease_axis);
......
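Note on the C++ change above: the old code constructed `starts`, `ends`, and `axes` with `rank` zero-initialized elements and then appended the real values with `push_back`, so the vectors handed to `slice<T>` ended up twice as long as intended, with a leading block of zeros (and `starts` was `int` rather than `int64_t`). The fix writes into the pre-sized slots, keeping each vector at length `rank`. A minimal standalone sketch of the pitfall follows; it is not Paddle-specific and the names are illustrative only:

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  const size_t rank = 3;

  // Buggy pattern: the vector already holds `rank` zeros, so push_back
  // appends after them and the size grows to 2 * rank.
  std::vector<int64_t> starts(rank, 0);
  for (size_t i = 0; i < rank; ++i) {
    starts.push_back(static_cast<int64_t>(i + 1));
  }
  std::cout << starts.size() << '\n';  // prints 6, and starts[0..2] are all zero

  // Fixed pattern: assign into the existing slots, size stays at rank.
  std::vector<int64_t> fixed(rank, 0);
  for (size_t i = 0; i < rank; ++i) {
    fixed[i] = static_cast<int64_t>(i + 1);
  }
  std::cout << fixed.size() << '\n';  // prints 3
  return 0;
}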
@@ -51,7 +51,6 @@ class TestPadOp(OpTest):
         }
         self.prim_op_type = "prim"
         self.public_python_api = pad_wrapper
-        self.enable_cinn = False

     def get_dtype(self):
         return np.float64
......
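With the slice arguments in `pad_grad` built correctly, the test no longer needs to opt out of CINN: removing `self.enable_cinn = False` lets `TestPadOp` run with the framework's default CINN setting. This reading is inferred from the commit title; the rest of the test setup is not shown here.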