提交 09adb769 编写于 作者: W wanghaoshuang

Fix code style

上级 bfe7e242
...@@ -57,16 +57,14 @@ class BlockExpandOp : public framework::OperatorWithKernel { ...@@ -57,16 +57,14 @@ class BlockExpandOp : public framework::OperatorWithKernel {
class BlockExpandOpMaker : public framework::OpProtoAndCheckerMaker { class BlockExpandOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
BlockExpandOpMaker(framework::OpProto* proto, BlockExpandOpMaker(OpProto* proto, OpAttrChecker* op_checker)
framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", R"DOC( AddInput("X",
(Tensor)The input tensor has NCHW format. "(Tensor)The input tensor has NCHW format."
N: batch size "N: batch size"
C: channels "C: channels"
H: height "H: height"
W: width "W: width");
)DOC");
AddOutput("Out", "(LodTensor)The output data of block_expand op,"); AddOutput("Out", "(LodTensor)The output data of block_expand op,");
AddAttr<int>("block_height", "(int)height of block."); AddAttr<int>("block_height", "(int)height of block.");
AddAttr<int>("block_width", "(int)width of block."); AddAttr<int>("block_width", "(int)width of block.");
...@@ -155,7 +153,8 @@ namespace ops = paddle::operators; ...@@ -155,7 +153,8 @@ namespace ops = paddle::operators;
REGISTER_OP(block_expand, ops::BlockExpandOp, ops::BlockExpandOpMaker, REGISTER_OP(block_expand, ops::BlockExpandOp, ops::BlockExpandOpMaker,
block_expand_grad, ops::BlockExpandGradOp); block_expand_grad, ops::BlockExpandGradOp);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
block_expand, ops::BlockExpandKernel<paddle::platform::CPUPlace, float>); block_expand,
ops::BlockExpandKernel<paddle::platform::CPUDeviceContext, float>);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
block_expand_grad, block_expand_grad,
ops::BlockExpandGradKernel<paddle::platform::CPUPlace, float>); ops::BlockExpandGradKernel<paddle::platform::CPUDeviceContext, float>);
...@@ -17,8 +17,9 @@ ...@@ -17,8 +17,9 @@
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL( REGISTER_OP_CUDA_KERNEL(
block_expand, ops::BlockExpandKernel<paddle::platform::GPUPlace, float>); block_expand,
REGISTER_OP_GPU_KERNEL( ops::BlockExpandKernel<paddle::platform::CUDADeviceContext, float>);
REGISTER_OP_CUDA_KERNEL(
block_expand_grad, block_expand_grad,
ops::BlockExpandGradKernel<paddle::platform::GPUPlace, float>); ops::BlockExpandGradKernel<paddle::platform::CUDADeviceContext, float>);
...@@ -31,7 +31,7 @@ inline int get_output_size(int img_size, int block_size, int stride, ...@@ -31,7 +31,7 @@ inline int get_output_size(int img_size, int block_size, int stride,
return (1 + (img_size + 2 * padding - block_size + stride - 1) / stride); return (1 + (img_size + 2 * padding - block_size + stride - 1) / stride);
} }
template <typename Place, typename T> template <typename DeviceContext, typename T>
class BlockExpandKernel : public framework::OpKernel<T> { class BlockExpandKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
...@@ -71,8 +71,9 @@ class BlockExpandKernel : public framework::OpKernel<T> { ...@@ -71,8 +71,9 @@ class BlockExpandKernel : public framework::OpKernel<T> {
img_channels, block_height, img_channels, block_height,
block_width}); block_width});
math::Im2ColFunctor<math::ColFormat::kOCF, Place, T> f; math::Im2ColFunctor<math::ColFormat::kOCF, DeviceContext, T> f;
f(ctx.device_context(), src, dilations, strides, paddings, &dst); auto& dev_ctx = ctx.template device_context<DeviceContext>();
f(dev_ctx, src, dilations, strides, paddings, &dst);
} }
out->Resize(out_dims); out->Resize(out_dims);
...@@ -87,7 +88,7 @@ class BlockExpandKernel : public framework::OpKernel<T> { ...@@ -87,7 +88,7 @@ class BlockExpandKernel : public framework::OpKernel<T> {
} }
}; };
template <typename Place, typename T> template <typename DeviceContext, typename T>
class BlockExpandGradKernel : public framework::OpKernel<T> { class BlockExpandGradKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
...@@ -98,7 +99,8 @@ class BlockExpandGradKernel : public framework::OpKernel<T> { ...@@ -98,7 +99,8 @@ class BlockExpandGradKernel : public framework::OpKernel<T> {
d_x->mutable_data<T>(ctx.GetPlace()); d_x->mutable_data<T>(ctx.GetPlace());
auto x_v = framework::EigenVector<T>::Flatten(*d_x); auto x_v = framework::EigenVector<T>::Flatten(*d_x);
x_v.device(ctx.GetEigenDevice<Place>()) = x_v.constant(0.0); auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
x_v.device(place) = x_v.constant(0.0);
auto in_dim = in->dims(); auto in_dim = in->dims();
int batch_size = in_dim[0]; int batch_size = in_dim[0];
...@@ -131,8 +133,9 @@ class BlockExpandGradKernel : public framework::OpKernel<T> { ...@@ -131,8 +133,9 @@ class BlockExpandGradKernel : public framework::OpKernel<T> {
const Tensor src = d_out->Slice(i, i + 1).Resize( const Tensor src = d_out->Slice(i, i + 1).Resize(
{output_height, output_width, img_channels, block_height, {output_height, output_width, img_channels, block_height,
block_width}); block_width});
math::Col2ImFunctor<math::ColFormat::kOCF, Place, T> f; math::Col2ImFunctor<math::ColFormat::kOCF, DeviceContext, T> f;
f(ctx.device_context(), src, dilations, strides, paddings, &dst); auto& dev_ctx = ctx.template device_context<DeviceContext>();
f(dev_ctx, src, dilations, strides, paddings, &dst);
} }
d_out->Resize(d_out_dims); d_out->Resize(d_out_dims);
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册