Commit 09adb769 authored by wanghaoshuang

Fix code style

Parent bfe7e242
@@ -57,16 +57,14 @@ class BlockExpandOp : public framework::OperatorWithKernel {
 class BlockExpandOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  BlockExpandOpMaker(framework::OpProto* proto,
-                     framework::OpAttrChecker* op_checker)
+  BlockExpandOpMaker(OpProto* proto, OpAttrChecker* op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
-    AddInput("X", R"DOC(
-(Tensor)The input tensor has NCHW format.
-    N: batch size
-    C: channels
-    H: height
-    W: width
-)DOC");
+    AddInput("X",
+             "(Tensor)The input tensor has NCHW format."
+             "N: batch size"
+             "C: channels"
+             "H: height"
+             "W: width");
     AddOutput("Out", "(LodTensor)The output data of block_expand op,");
     AddAttr<int>("block_height", "(int)height of block.");
     AddAttr<int>("block_width", "(int)width of block.");
@@ -155,7 +153,8 @@ namespace ops = paddle::operators;
 REGISTER_OP(block_expand, ops::BlockExpandOp, ops::BlockExpandOpMaker,
             block_expand_grad, ops::BlockExpandGradOp);
 REGISTER_OP_CPU_KERNEL(
-    block_expand, ops::BlockExpandKernel<paddle::platform::CPUPlace, float>);
+    block_expand,
+    ops::BlockExpandKernel<paddle::platform::CPUDeviceContext, float>);
 REGISTER_OP_CPU_KERNEL(
     block_expand_grad,
-    ops::BlockExpandGradKernel<paddle::platform::CPUPlace, float>);
+    ops::BlockExpandGradKernel<paddle::platform::CPUDeviceContext, float>);
@@ -17,8 +17,9 @@
 namespace ops = paddle::operators;
-REGISTER_OP_GPU_KERNEL(
-    block_expand, ops::BlockExpandKernel<paddle::platform::GPUPlace, float>);
-REGISTER_OP_GPU_KERNEL(
+REGISTER_OP_CUDA_KERNEL(
+    block_expand,
+    ops::BlockExpandKernel<paddle::platform::CUDADeviceContext, float>);
+REGISTER_OP_CUDA_KERNEL(
     block_expand_grad,
-    ops::BlockExpandGradKernel<paddle::platform::GPUPlace, float>);
+    ops::BlockExpandGradKernel<paddle::platform::CUDADeviceContext, float>);
@@ -31,7 +31,7 @@ inline int get_output_size(int img_size, int block_size, int stride,
   return (1 + (img_size + 2 * padding - block_size + stride - 1) / stride);
 }
-template <typename Place, typename T>
+template <typename DeviceContext, typename T>
 class BlockExpandKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
@@ -71,8 +71,9 @@ class BlockExpandKernel : public framework::OpKernel<T> {
                                                 img_channels, block_height,
                                                 block_width});
-      math::Im2ColFunctor<math::ColFormat::kOCF, Place, T> f;
-      f(ctx.device_context(), src, dilations, strides, paddings, &dst);
+      math::Im2ColFunctor<math::ColFormat::kOCF, DeviceContext, T> f;
+      auto& dev_ctx = ctx.template device_context<DeviceContext>();
+      f(dev_ctx, src, dilations, strides, paddings, &dst);
     }
     out->Resize(out_dims);
@@ -87,7 +88,7 @@ class BlockExpandKernel : public framework::OpKernel<T> {
   }
 };
-template <typename Place, typename T>
+template <typename DeviceContext, typename T>
 class BlockExpandGradKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
@@ -98,7 +99,8 @@ class BlockExpandGradKernel : public framework::OpKernel<T> {
     d_x->mutable_data<T>(ctx.GetPlace());
     auto x_v = framework::EigenVector<T>::Flatten(*d_x);
-    x_v.device(ctx.GetEigenDevice<Place>()) = x_v.constant(0.0);
+    auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
+    x_v.device(place) = x_v.constant(0.0);
     auto in_dim = in->dims();
     int batch_size = in_dim[0];
@@ -131,8 +133,9 @@ class BlockExpandGradKernel : public framework::OpKernel<T> {
       const Tensor src = d_out->Slice(i, i + 1).Resize(
           {output_height, output_width, img_channels, block_height,
            block_width});
-      math::Col2ImFunctor<math::ColFormat::kOCF, Place, T> f;
-      f(ctx.device_context(), src, dilations, strides, paddings, &dst);
+      math::Col2ImFunctor<math::ColFormat::kOCF, DeviceContext, T> f;
+      auto& dev_ctx = ctx.template device_context<DeviceContext>();
+      f(dev_ctx, src, dilations, strides, paddings, &dst);
     }
     d_out->Resize(d_out_dims);
   }
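
Note: besides the style cleanup, the diff moves the kernels from the `Place` template parameter to `DeviceContext`, pulling both the typed device context and the Eigen device out of the `ExecutionContext`. Below is a minimal sketch of that pattern, not taken from this commit: the `zero_out` op name, its `Out` output, and the header paths are assumptions for illustration.

```cpp
// Sketch of the DeviceContext-templated kernel pattern used in the diff.
// Op name "zero_out", the "Out" slot, and header paths are hypothetical.
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"

namespace paddle {
namespace operators {

template <typename DeviceContext, typename T>
class ZeroOutKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* out = ctx.Output<framework::Tensor>("Out");
    out->mutable_data<T>(ctx.GetPlace());

    // Typed device context: the object handed to functors such as the
    // Im2ColFunctor/Col2ImFunctor calls in the diff.
    auto& dev_ctx = ctx.template device_context<DeviceContext>();

    // The Eigen device comes from the same context; this replaces the old
    // ctx.GetEigenDevice<Place>() call that the diff removes.
    auto out_v = framework::EigenVector<T>::Flatten(*out);
    out_v.device(*dev_ctx.eigen_device()) = out_v.constant(static_cast<T>(0));
  }
};

}  // namespace operators
}  // namespace paddle

// Registration names the device context explicitly, mirroring the
// REGISTER_OP_CPU_KERNEL / REGISTER_OP_CUDA_KERNEL changes above.
// The op definition itself (OperatorWithKernel, OpMaker) is omitted here.
namespace ops = paddle::operators;
REGISTER_OP_CPU_KERNEL(
    zero_out, ops::ZeroOutKernel<paddle::platform::CPUDeviceContext, float>);
```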