Commit d87ac4de authored by wangyang59

GPU of bilinear_interp_op done

Parent: ad3b3d9d
@@ -11,6 +11,8 @@
 #include "paddle/fluid/operators/bilinear_interp_op.cu.h"
 #include "paddle/fluid/operators/bilinear_interp_op.h"
+#include "paddle/fluid/operators/math/math_function.h"
+#include "paddle/fluid/platform/cuda_helper.h"

 namespace paddle {
 namespace operators {
@@ -64,6 +66,11 @@ class BilinearInterpGradOpCUDAKernel : public framework::OpKernel<T> {
     auto* d_input = d_input_t->mutable_data<T>(ctx.GetPlace());
     auto* d_output = d_output_t->data<T>();

+    auto& device_ctx =
+        ctx.template device_context<platform::CUDADeviceContext>();
+    math::SetConstant<platform::CUDADeviceContext, T> zero;
+    zero(device_ctx, d_input_t, static_cast<T>(0.0));
+
     int out_h = ctx.Attr<int>("out_h");
     int out_w = ctx.Attr<int>("out_w");
     int batch_size = d_input_t->dims()[0];
...
@@ -10,16 +10,13 @@
 limitations under the License. */

 #pragma once

-#include "paddle/fluid/framework/eigen.h"
 #include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/operators/math/math_function.h"

 namespace paddle {
 namespace operators {

 using Tensor = framework::Tensor;

-template <typename T, int MajorType = Eigen::RowMajor,
-          typename IndexType = Eigen::DenseIndex>
-using EigenVector = framework::EigenVector<T, MajorType, IndexType>;

 template <typename T>
 class BilinearInterpKernel : public framework::OpKernel<T> {
@@ -89,6 +86,11 @@ class BilinearInterpGradKernel : public framework::OpKernel<T> {
     auto* d_input = d_input_t->mutable_data<T>(ctx.GetPlace());
     auto* d_output = d_output_t->data<T>();

+    auto& device_ctx =
+        ctx.template device_context<platform::CPUDeviceContext>();
+    math::SetConstant<platform::CPUDeviceContext, T> zero;
+    zero(device_ctx, d_input_t, static_cast<T>(0.0));
+
     int out_h = ctx.Attr<int>("out_h");
     int out_w = ctx.Attr<int>("out_w");
     int batch_size = d_input_t->dims()[0];
...
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册