提交 5d33481c 编写于 作者: F fengjiayi

Add bilinear interp supporting for uint8

上级 a29cb4be
...@@ -110,6 +110,8 @@ REGISTER_OPERATOR(bilinear_interp, ops::BilinearInterpOp, ...@@ -110,6 +110,8 @@ REGISTER_OPERATOR(bilinear_interp, ops::BilinearInterpOp,
ops::BilinearInterpOpMaker, ops::BilinearInterpOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>); paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(bilinear_interp_grad, ops::BilinearInterpOpGrad); REGISTER_OPERATOR(bilinear_interp_grad, ops::BilinearInterpOpGrad);
REGISTER_OP_CPU_KERNEL(bilinear_interp, ops::BilinearInterpKernel<float>); REGISTER_OP_CPU_KERNEL(bilinear_interp, ops::BilinearInterpKernel<float>,
ops::BilinearInterpKernel<uint8_t>);
REGISTER_OP_CPU_KERNEL(bilinear_interp_grad, REGISTER_OP_CPU_KERNEL(bilinear_interp_grad,
ops::BilinearInterpGradKernel<float>); ops::BilinearInterpGradKernel<float>,
ops::BilinearInterpGradKernel<uint8_t>);
...@@ -46,8 +46,10 @@ class BilinearInterpKernel : public framework::OpKernel<T> { ...@@ -46,8 +46,10 @@ class BilinearInterpKernel : public framework::OpKernel<T> {
int in_chw = channels * in_hw; int in_chw = channels * in_hw;
int out_chw = channels * out_hw; int out_chw = channels * out_hw;
T ratio_h = (out_h > 1) ? static_cast<T>(in_h - 1) / (out_h - 1) : 0.f; float ratio_h =
T ratio_w = (out_w > 1) ? static_cast<T>(in_w - 1) / (out_w - 1) : 0.f; (out_h > 1) ? static_cast<float>(in_h - 1) / (out_h - 1) : 0.f;
float ratio_w =
(out_w > 1) ? static_cast<float>(in_w - 1) / (out_w - 1) : 0.f;
if (in_h == out_h && in_w == out_w) { if (in_h == out_h && in_w == out_w) {
memcpy(output, input, input_t->numel() * sizeof(T)); memcpy(output, input, input_t->numel() * sizeof(T));
...@@ -56,14 +58,14 @@ class BilinearInterpKernel : public framework::OpKernel<T> { ...@@ -56,14 +58,14 @@ class BilinearInterpKernel : public framework::OpKernel<T> {
for (int i = 0; i < out_h; ++i) { // loop for images for (int i = 0; i < out_h; ++i) { // loop for images
int h = ratio_h * i; int h = ratio_h * i;
int hid = (h < in_h - 1) ? 1 : 0; int hid = (h < in_h - 1) ? 1 : 0;
T h1lambda = ratio_h * i - h; float h1lambda = ratio_h * i - h;
T h2lambda = 1 - h1lambda; float h2lambda = 1.f - h1lambda;
for (int j = 0; j < out_w; ++j) { for (int j = 0; j < out_w; ++j) {
int w = ratio_w * j; int w = ratio_w * j;
int wid = (w < in_w - 1) ? 1 : 0; int wid = (w < in_w - 1) ? 1 : 0;
T w1lambda = ratio_w * j - w; float w1lambda = ratio_w * j - w;
T w2lambda = 1 - w1lambda; float w2lambda = 1.f - w1lambda;
// calculate four position for bilinear interpolation // calculate four position for bilinear interpolation
const T* in_pos = &input[k * in_chw + h * in_w + w]; const T* in_pos = &input[k * in_chw + h * in_w + w];
T* out_pos = &output[k * out_chw + i * out_w + j]; T* out_pos = &output[k * out_chw + i * out_w + j];
...@@ -117,8 +119,10 @@ class BilinearInterpGradKernel : public framework::OpKernel<T> { ...@@ -117,8 +119,10 @@ class BilinearInterpGradKernel : public framework::OpKernel<T> {
int in_chw = channels * in_hw; int in_chw = channels * in_hw;
int out_chw = channels * out_hw; int out_chw = channels * out_hw;
T ratio_h = (out_h > 1) ? static_cast<T>(in_h - 1) / (out_h - 1) : 0.f; float ratio_h =
T ratio_w = (out_w > 1) ? static_cast<T>(in_w - 1) / (out_w - 1) : 0.f; (out_h > 1) ? static_cast<float>(in_h - 1) / (out_h - 1) : 0.f;
float ratio_w =
(out_w > 1) ? static_cast<float>(in_w - 1) / (out_w - 1) : 0.f;
if (in_h == out_h && in_w == out_w) { if (in_h == out_h && in_w == out_w) {
memcpy(d_input, d_output, d_input_t->numel() * sizeof(T)); memcpy(d_input, d_output, d_input_t->numel() * sizeof(T));
...@@ -127,14 +131,14 @@ class BilinearInterpGradKernel : public framework::OpKernel<T> { ...@@ -127,14 +131,14 @@ class BilinearInterpGradKernel : public framework::OpKernel<T> {
for (int i = 0; i < out_h; ++i) { // loop for images for (int i = 0; i < out_h; ++i) { // loop for images
int h = ratio_h * i; int h = ratio_h * i;
int hid = (h < in_h - 1) ? 1 : 0; int hid = (h < in_h - 1) ? 1 : 0;
T h1lambda = ratio_h * i - h; float h1lambda = ratio_h * i - h;
T h2lambda = 1 - h1lambda; float h2lambda = 1 - h1lambda;
for (int j = 0; j < out_w; ++j) { for (int j = 0; j < out_w; ++j) {
int w = ratio_w * j; int w = ratio_w * j;
int wid = (w < in_w - 1) ? 1 : 0; int wid = (w < in_w - 1) ? 1 : 0;
T w1lambda = ratio_w * j - w; float w1lambda = ratio_w * j - w;
T w2lambda = 1 - w1lambda; float w2lambda = 1 - w1lambda;
T* in_pos = &d_input[k * in_chw + h * in_w + w]; T* in_pos = &d_input[k * in_chw + h * in_w + w];
const T* out_pos = &d_output[k * out_chw + i * out_w + j]; const T* out_pos = &d_output[k * out_chw + i * out_w + j];
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册