提交 1f2afccf 作者: SunGaofeng 提交者: liuwei1031

test=develop (#16783)

上级 d105c06b
@@ -121,9 +121,11 @@ class AffineGridOpKernel : public framework::OpKernel<T> {
     // TODO(wanghaoshuang): Refine batched matrix multiply
     auto blas = math::GetBlas<DeviceContext, T>(ctx);
     for (int i = 0; i < n; ++i) {
-      Tensor sliced_grid = grid.Slice(i, i + 1).Resize({h * w, 3});
+      Tensor sliced_grid = grid.Slice(i, i + 1).Resize(
+          {static_cast<int64_t>(h) * static_cast<int64_t>(w), 3});
       Tensor sliced_theta = theta->Slice(i, i + 1).Resize({2, 3});
-      Tensor sliced_out = output->Slice(i, i + 1).Resize({h * w, 2});
+      Tensor sliced_out = output->Slice(i, i + 1).Resize(
+          {static_cast<int64_t>(h) * static_cast<int64_t>(w), 2});
       blas.MatMul(sliced_grid, false, sliced_theta, true, T(1), &sliced_out,
                   T(0));
     }
@@ -161,8 +163,10 @@ class AffineGridGradOpKernel : public framework::OpKernel<T> {
     // TODO(wanghaoshuang): Refine batched matrix multiply
     auto blas = math::GetBlas<DeviceContext, T>(ctx);
     for (int i = 0; i < n; ++i) {
-      Tensor sliced_grid = grid.Slice(i, i + 1).Resize({h * w, 3});
-      Tensor sliced_out_grad = output_grad->Slice(i, i + 1).Resize({h * w, 2});
+      Tensor sliced_grid = grid.Slice(i, i + 1).Resize(
+          {static_cast<int64_t>(h) * static_cast<int64_t>(w), 3});
+      Tensor sliced_out_grad = output_grad->Slice(i, i + 1).Resize(
+          {static_cast<int64_t>(h) * static_cast<int64_t>(w), 2});
       Tensor sliced_theta_grad = theta_grad->Slice(i, i + 1).Resize({2, 3});
       blas.MatMul(sliced_out_grad, true, sliced_grid, false, T(1),
                   &sliced_theta_grad, T(0));
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册