提交 6961a94e 编写于 作者: T tink2123

avoid out_size less than 1

test=develop
上级 e7eb08fe
...@@ -220,12 +220,17 @@ class InterpolateOpCUDAKernel : public framework::OpKernel<T> { ...@@ -220,12 +220,17 @@ class InterpolateOpCUDAKernel : public framework::OpKernel<T> {
int in_chw = c * in_hw; int in_chw = c * in_hw;
int out_chw = c * out_hw; int out_chw = c * out_hw;
float ratio_h = (align_corners && out_h > 1) float ratio_h = 0.f;
? static_cast<float>(in_h - 1) / (out_h - 1) float ratio_w = 0.f;
if (out_h > 1) {
ratio_h = (align_corners) ? static_cast<float>(in_h - 1) / (out_h - 1)
: static_cast<float>(in_h) / out_h; : static_cast<float>(in_h) / out_h;
float ratio_w = (align_corners && out_w > 1) }
if (out_w > 1) {
ratio_w = (align_corners && out_w > 1)
? static_cast<float>(in_w - 1) / (out_w - 1) ? static_cast<float>(in_w - 1) / (out_w - 1)
: static_cast<float>(in_w) / out_w; : static_cast<float>(in_w) / out_w;
}
if (in_h == out_h && in_w == out_w) { if (in_h == out_h && in_w == out_w) {
framework::TensorCopy(*input, ctx.GetPlace(), output); framework::TensorCopy(*input, ctx.GetPlace(), output);
...@@ -290,12 +295,17 @@ class InterpolateGradOpCUDAKernel : public framework::OpKernel<T> { ...@@ -290,12 +295,17 @@ class InterpolateGradOpCUDAKernel : public framework::OpKernel<T> {
int in_chw = c * in_hw; int in_chw = c * in_hw;
int out_chw = c * out_hw; int out_chw = c * out_hw;
float ratio_h = (align_corners && out_h > 1) float ratio_h = 0.f;
? static_cast<float>(in_h - 1) / (out_h - 1) float ratio_w = 0.f;
if (out_h > 1) {
ratio_h = (align_corners) ? static_cast<float>(in_h - 1) / (out_h - 1)
: static_cast<float>(in_h) / out_h; : static_cast<float>(in_h) / out_h;
float ratio_w = (align_corners && out_w > 1) }
if (out_w > 1) {
ratio_w = (align_corners && out_w > 1)
? static_cast<float>(in_w - 1) / (out_w - 1) ? static_cast<float>(in_w - 1) / (out_w - 1)
: static_cast<float>(in_w) / out_w; : static_cast<float>(in_w) / out_w;
}
if (in_h == out_h && in_w == out_w) { if (in_h == out_h && in_w == out_w) {
framework::TensorCopy(*output_grad, ctx.GetPlace(), input_grad); framework::TensorCopy(*output_grad, ctx.GetPlace(), input_grad);
......
...@@ -191,12 +191,18 @@ class InterpolateKernel : public framework::OpKernel<T> { ...@@ -191,12 +191,18 @@ class InterpolateKernel : public framework::OpKernel<T> {
return; return;
} }
float ratio_h = (align_corners && out_h > 1) float ratio_h = 0.f;
? static_cast<float>(in_h - 1) / (out_h - 1) float ratio_w = 0.f;
if (out_h > 1) {
ratio_h = (align_corners) ? static_cast<float>(in_h - 1) / (out_h - 1)
: static_cast<float>(in_h) / out_h; : static_cast<float>(in_h) / out_h;
float ratio_w = (align_corners && out_w > 1) }
if (out_w > 1) {
ratio_w = (align_corners && out_w > 1)
? static_cast<float>(in_w - 1) / (out_w - 1) ? static_cast<float>(in_w - 1) / (out_w - 1)
: static_cast<float>(in_w) / out_w; : static_cast<float>(in_w) / out_w;
}
if ("bilinear" == interp_method) { if ("bilinear" == interp_method) {
BilinearInterpolation<T>(*input, output, ratio_h, ratio_w, in_h, in_w, n, BilinearInterpolation<T>(*input, output, ratio_h, ratio_w, in_h, in_w, n,
...@@ -244,12 +250,18 @@ class InterpolateGradKernel : public framework::OpKernel<T> { ...@@ -244,12 +250,18 @@ class InterpolateGradKernel : public framework::OpKernel<T> {
return; return;
} }
float ratio_h = (align_corners && out_h > 1) float ratio_h = 0.f;
? static_cast<float>(in_h - 1) / (out_h - 1) float ratio_w = 0.f;
if (out_h > 1) {
ratio_h = (align_corners) ? static_cast<float>(in_h - 1) / (out_h - 1)
: static_cast<float>(in_h) / out_h; : static_cast<float>(in_h) / out_h;
float ratio_w = (align_corners && out_w > 1) }
if (out_w > 1) {
ratio_w = (align_corners && out_w > 1)
? static_cast<float>(in_w - 1) / (out_w - 1) ? static_cast<float>(in_w - 1) / (out_w - 1)
: static_cast<float>(in_w) / out_w; : static_cast<float>(in_w) / out_w;
}
if ("bilinear" == interp_method) { if ("bilinear" == interp_method) {
BilinearInterpolationGrad<T>(*output_grad, input_grad, ratio_h, ratio_w, BilinearInterpolationGrad<T>(*output_grad, input_grad, ratio_h, ratio_w,
......
...@@ -37,11 +37,13 @@ def bilinear_interp_np(input, ...@@ -37,11 +37,13 @@ def bilinear_interp_np(input,
batch_size, channel, in_h, in_w = input.shape batch_size, channel, in_h, in_w = input.shape
ratio_h = ratio_w = 0.0 ratio_h = ratio_w = 0.0
if (align_corners and out_h > 1): if out_h > 1:
if (align_corners):
ratio_h = (in_h - 1.0) / (out_h - 1.0) ratio_h = (in_h - 1.0) / (out_h - 1.0)
else: else:
ratio_h = 1.0 * in_h / out_h ratio_h = 1.0 * in_h / out_h
if (align_corners and out_w > 1): if out_w > 1:
if (align_corners):
ratio_w = (in_w - 1.0) / (out_w - 1.0) ratio_w = (in_w - 1.0) / (out_w - 1.0)
else: else:
ratio_w = 1.0 * in_w / out_w ratio_w = 1.0 * in_w / out_w
......
...@@ -36,11 +36,13 @@ def nearest_neighbor_interp_np(X, ...@@ -36,11 +36,13 @@ def nearest_neighbor_interp_np(X,
n, c, in_h, in_w = X.shape n, c, in_h, in_w = X.shape
ratio_h = ratio_w = 0.0 ratio_h = ratio_w = 0.0
if (align_corners and out_h > 1): if (out_h > 1):
if (align_corners):
ratio_h = (in_h - 1.0) / (out_h - 1.0) ratio_h = (in_h - 1.0) / (out_h - 1.0)
else: else:
ratio_h = 1.0 * in_h / out_h ratio_h = 1.0 * in_h / out_h
if (align_corners and out_w > 1): if (out_w > 1):
if (align_corners):
ratio_w = (in_w - 1.0) / (out_w - 1.0) ratio_w = (in_w - 1.0) / (out_w - 1.0)
else: else:
ratio_w = 1.0 * in_w / out_w ratio_w = 1.0 * in_w / out_w
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册