diff --git a/paddle/fluid/operators/bilinear_interp_op.cu b/paddle/fluid/operators/bilinear_interp_op.cu index c4abdbd3b598de022bede2c5bf54cf3bea636f36..f5899e9067105c7637448530b4446df8f78d7da1 100644 --- a/paddle/fluid/operators/bilinear_interp_op.cu +++ b/paddle/fluid/operators/bilinear_interp_op.cu @@ -11,6 +11,8 @@ #include "paddle/fluid/operators/bilinear_interp_op.cu.h" #include "paddle/fluid/operators/bilinear_interp_op.h" +#include "paddle/fluid/operators/math/math_function.h" +#include "paddle/fluid/platform/cuda_helper.h" namespace paddle { namespace operators { @@ -64,6 +66,11 @@ class BilinearInterpGradOpCUDAKernel : public framework::OpKernel { auto* d_input = d_input_t->mutable_data(ctx.GetPlace()); auto* d_output = d_output_t->data(); + auto& device_ctx = + ctx.template device_context(); + math::SetConstant zero; + zero(device_ctx, d_input_t, static_cast(0.0)); + int out_h = ctx.Attr("out_h"); int out_w = ctx.Attr("out_w"); int batch_size = d_input_t->dims()[0]; diff --git a/paddle/fluid/operators/bilinear_interp_op.h b/paddle/fluid/operators/bilinear_interp_op.h index 8dbc7a7485d887b706ab0f3e2b71fbec26f8eed6..f6cd77e4d49b53ecde6a84908cdffc7e1e02ac6a 100644 --- a/paddle/fluid/operators/bilinear_interp_op.h +++ b/paddle/fluid/operators/bilinear_interp_op.h @@ -10,16 +10,13 @@ limitations under the License. */ #pragma once -#include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/math/math_function.h" namespace paddle { namespace operators { using Tensor = framework::Tensor; -template -using EigenVector = framework::EigenVector; template class BilinearInterpKernel : public framework::OpKernel { @@ -89,6 +86,11 @@ class BilinearInterpGradKernel : public framework::OpKernel { auto* d_input = d_input_t->mutable_data(ctx.GetPlace()); auto* d_output = d_output_t->data(); + auto& device_ctx = + ctx.template device_context(); + math::SetConstant zero; + zero(device_ctx, d_input_t, static_cast(0.0)); + int out_h = ctx.Attr("out_h"); int out_w = ctx.Attr("out_w"); int batch_size = d_input_t->dims()[0];