From 09adb769037b34fbe8a50fd48bc3284f13456f3a Mon Sep 17 00:00:00 2001
From: wanghaoshuang
Date: Wed, 17 Jan 2018 11:15:54 +0800
Subject: [PATCH] Fix code style

---
 paddle/operators/block_expand_op.cc | 21 ++++++++++-----------
 paddle/operators/block_expand_op.cu |  9 +++++----
 paddle/operators/block_expand_op.h  | 17 ++++++++++-------
 3 files changed, 25 insertions(+), 22 deletions(-)

diff --git a/paddle/operators/block_expand_op.cc b/paddle/operators/block_expand_op.cc
index bef82183b8c..f9b75ffee70 100644
--- a/paddle/operators/block_expand_op.cc
+++ b/paddle/operators/block_expand_op.cc
@@ -57,16 +57,14 @@ class BlockExpandOp : public framework::OperatorWithKernel {
 
 class BlockExpandOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  BlockExpandOpMaker(framework::OpProto* proto,
-                     framework::OpAttrChecker* op_checker)
+  BlockExpandOpMaker(OpProto* proto, OpAttrChecker* op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
-    AddInput("X", R"DOC(
-(Tensor)The input tensor has NCHW format.
-    N: batch size
-    C: channels
-    H: height
-    W: width
-)DOC");
+    AddInput("X",
+             "(Tensor)The input tensor has NCHW format."
+             "N: batch size"
+             "C: channels"
+             "H: height"
+             "W: width");
     AddOutput("Out", "(LodTensor)The output data of block_expand op,");
     AddAttr<int>("block_height", "(int)height of block.");
     AddAttr<int>("block_width", "(int)width of block.");
@@ -155,7 +153,8 @@ namespace ops = paddle::operators;
 REGISTER_OP(block_expand, ops::BlockExpandOp, ops::BlockExpandOpMaker,
             block_expand_grad, ops::BlockExpandGradOp);
 REGISTER_OP_CPU_KERNEL(
-    block_expand, ops::BlockExpandKernel<paddle::platform::CPUPlace, float>);
+    block_expand,
+    ops::BlockExpandKernel<paddle::platform::CPUDeviceContext, float>);
 REGISTER_OP_CPU_KERNEL(
     block_expand_grad,
-    ops::BlockExpandGradKernel<paddle::platform::CPUPlace, float>);
+    ops::BlockExpandGradKernel<paddle::platform::CPUDeviceContext, float>);
diff --git a/paddle/operators/block_expand_op.cu b/paddle/operators/block_expand_op.cu
index 492ac0c9b2e..c17b1138076 100644
--- a/paddle/operators/block_expand_op.cu
+++ b/paddle/operators/block_expand_op.cu
@@ -17,8 +17,9 @@
 
 namespace ops = paddle::operators;
 
-REGISTER_OP_GPU_KERNEL(
-    block_expand, ops::BlockExpandKernel<paddle::platform::GPUPlace, float>);
-REGISTER_OP_GPU_KERNEL(
+REGISTER_OP_CUDA_KERNEL(
+    block_expand,
+    ops::BlockExpandKernel<paddle::platform::CUDADeviceContext, float>);
+REGISTER_OP_CUDA_KERNEL(
     block_expand_grad,
-    ops::BlockExpandGradKernel<paddle::platform::GPUPlace, float>);
+    ops::BlockExpandGradKernel<paddle::platform::CUDADeviceContext, float>);
diff --git a/paddle/operators/block_expand_op.h b/paddle/operators/block_expand_op.h
index 2e4f0cb6f1d..72760fb23c1 100644
--- a/paddle/operators/block_expand_op.h
+++ b/paddle/operators/block_expand_op.h
@@ -31,7 +31,7 @@ inline int get_output_size(int img_size, int block_size, int stride,
   return (1 + (img_size + 2 * padding - block_size + stride - 1) / stride);
 }
 
-template <typename Place, typename T>
+template <typename DeviceContext, typename T>
 class BlockExpandKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
@@ -71,8 +71,9 @@ class BlockExpandKernel : public framework::OpKernel<T> {
                                                 img_channels,
                                                 block_height,
                                                 block_width});
-      math::Im2ColFunctor<math::ColFormat::kOCF, Place, T> f;
-      f(ctx.device_context(), src, dilations, strides, paddings, &dst);
+      math::Im2ColFunctor<math::ColFormat::kOCF, DeviceContext, T> f;
+      auto& dev_ctx = ctx.template device_context<DeviceContext>();
+      f(dev_ctx, src, dilations, strides, paddings, &dst);
     }
     out->Resize(out_dims);
 
@@ -87,7 +88,7 @@ class BlockExpandKernel : public framework::OpKernel<T> {
   }
 };
 
-template <typename Place, typename T>
+template <typename DeviceContext, typename T>
 class BlockExpandGradKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
@@ -98,7 +99,8 @@ class BlockExpandGradKernel : public framework::OpKernel<T> {
     d_x->mutable_data<T>(ctx.GetPlace());
 
     auto x_v = framework::EigenVector<T>::Flatten(*d_x);
-    x_v.device(ctx.GetEigenDevice<Place>()) = x_v.constant(0.0);
+    auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
+    x_v.device(place) = x_v.constant(0.0);
 
     auto in_dim = in->dims();
     int batch_size = in_dim[0];
@@ -131,8 +133,9 @@ class BlockExpandGradKernel : public framework::OpKernel<T> {
       const Tensor src = d_out->Slice(i, i + 1).Resize(
           {output_height, output_width, img_channels, block_height,
            block_width});
-      math::Col2ImFunctor<math::ColFormat::kOCF, Place, T> f;
-      f(ctx.device_context(), src, dilations, strides, paddings, &dst);
+      math::Col2ImFunctor<math::ColFormat::kOCF, DeviceContext, T> f;
+      auto& dev_ctx = ctx.template device_context<DeviceContext>();
+      f(dev_ctx, src, dilations, strides, paddings, &dst);
     }
     d_out->Resize(d_out_dims);
   }
--
GitLab
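
Note (not part of the patch): beyond the style fixes, the hunks above follow the
framework-wide migration from Place-templated kernels (ctx.GetEigenDevice<Place>(),
REGISTER_OP_GPU_KERNEL) to DeviceContext-templated kernels (ctx.template
device_context<DeviceContext>(), REGISTER_OP_CUDA_KERNEL). Below is a minimal
sketch of that pattern under the framework headers of this era; the scale_by_two
op and its kernel are hypothetical, invented only to show the same calls the
diff uses, and the operator/shape-inference boilerplate is omitted.

    // Hypothetical op, NOT from this patch: one kernel body written against
    // the DeviceContext template parameter serves both CPU and CUDA builds.
    #include "paddle/framework/eigen.h"
    #include "paddle/framework/op_registry.h"

    namespace paddle {
    namespace operators {

    template <typename DeviceContext, typename T>
    class ScaleByTwoKernel : public framework::OpKernel<T> {
     public:
      void Compute(const framework::ExecutionContext& ctx) const override {
        auto* in = ctx.Input<framework::Tensor>("X");
        auto* out = ctx.Output<framework::Tensor>("Out");
        out->mutable_data<T>(ctx.GetPlace());
        auto x = framework::EigenVector<T>::Flatten(*in);
        auto y = framework::EigenVector<T>::Flatten(*out);
        // Fetch the concrete context (CPUDeviceContext or CUDADeviceContext)
        // from the ExecutionContext, exactly as the patched kernels do.
        auto& place =
            *ctx.template device_context<DeviceContext>().eigen_device();
        y.device(place) = x * static_cast<T>(2);
      }
    };

    }  // namespace operators
    }  // namespace paddle

    // In the .cc file: bind the kernel to the CPU device context.
    namespace ops = paddle::operators;
    REGISTER_OP_CPU_KERNEL(
        scale_by_two,
        ops::ScaleByTwoKernel<paddle::platform::CPUDeviceContext, float>);
    // In the .cu file, the same kernel binds to the CUDA context:
    // REGISTER_OP_CUDA_KERNEL(
    //     scale_by_two,
    //     ops::ScaleByTwoKernel<paddle::platform::CUDADeviceContext, float>);

With this split, functor call sites pass the typed context directly
(f(dev_ctx, ...)) rather than the type-erased ctx.device_context(), which is
why the patch introduces the local dev_ctx reference before each Im2Col/Col2Im
invocation.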