From d87ac4de34f5ddc9e80706b1ac892e6a7ab4cbb6 Mon Sep 17 00:00:00 2001
From: wangyang59
Date: Thu, 22 Mar 2018 16:15:28 -0700
Subject: [PATCH] GPU of bilinear_interp_op done

---
 paddle/fluid/operators/bilinear_interp_op.cu |  7 +++++++
 paddle/fluid/operators/bilinear_interp_op.h  | 10 ++++++----
 2 files changed, 13 insertions(+), 4 deletions(-)

diff --git a/paddle/fluid/operators/bilinear_interp_op.cu b/paddle/fluid/operators/bilinear_interp_op.cu
index c4abdbd3b59..f5899e90671 100644
--- a/paddle/fluid/operators/bilinear_interp_op.cu
+++ b/paddle/fluid/operators/bilinear_interp_op.cu
@@ -11,6 +11,8 @@
 #include "paddle/fluid/operators/bilinear_interp_op.cu.h"
 #include "paddle/fluid/operators/bilinear_interp_op.h"
+#include "paddle/fluid/operators/math/math_function.h"
+#include "paddle/fluid/platform/cuda_helper.h"
 
 namespace paddle {
 namespace operators {
 
@@ -64,6 +66,11 @@ class BilinearInterpGradOpCUDAKernel : public framework::OpKernel<T> {
     auto* d_input = d_input_t->mutable_data<T>(ctx.GetPlace());
     auto* d_output = d_output_t->data<T>();
 
+    auto& device_ctx =
+        ctx.template device_context<platform::CUDADeviceContext>();
+    math::SetConstant<platform::CUDADeviceContext, T> zero;
+    zero(device_ctx, d_input_t, static_cast<T>(0.0));
+
     int out_h = ctx.Attr<int>("out_h");
     int out_w = ctx.Attr<int>("out_w");
     int batch_size = d_input_t->dims()[0];
diff --git a/paddle/fluid/operators/bilinear_interp_op.h b/paddle/fluid/operators/bilinear_interp_op.h
index 8dbc7a7485d..f6cd77e4d49 100644
--- a/paddle/fluid/operators/bilinear_interp_op.h
+++ b/paddle/fluid/operators/bilinear_interp_op.h
@@ -10,16 +10,13 @@ limitations under the License.
 */
 
 #pragma once
-#include "paddle/fluid/framework/eigen.h"
 #include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/operators/math/math_function.h"
 
 namespace paddle {
 namespace operators {
 
 using Tensor = framework::Tensor;
-template <typename T, size_t D, int MajorType = Eigen::RowMajor,
-          typename IndexType = Eigen::DenseIndex>
-using EigenVector = framework::EigenVector<T, D, MajorType, IndexType>;
 
 template <typename T>
 class BilinearInterpKernel : public framework::OpKernel<T> {
@@ -89,6 +86,11 @@ class BilinearInterpGradKernel : public framework::OpKernel<T> {
     auto* d_input = d_input_t->mutable_data<T>(ctx.GetPlace());
     auto* d_output = d_output_t->data<T>();
 
+    auto& device_ctx =
+        ctx.template device_context<platform::CPUDeviceContext>();
+    math::SetConstant<platform::CPUDeviceContext, T> zero;
+    zero(device_ctx, d_input_t, static_cast<T>(0.0));
+
     int out_h = ctx.Attr<int>("out_h");
     int out_w = ctx.Attr<int>("out_w");
     int batch_size = d_input_t->dims()[0];
-- 
GitLab