提交 c42e2049 编写于 作者: H hedaoyuan

Refine code.

上级 bb546cf1
...@@ -84,8 +84,8 @@ class GemmConv2DKernel : public framework::OpKernel { ...@@ -84,8 +84,8 @@ class GemmConv2DKernel : public framework::OpKernel {
for (int g = 0; g < groups; g++) { for (int g = 0; g < groups; g++) {
// im2col // im2col
Tensor in_slice = in_batch.Slice<T>(g * in_step, (g + 1) * in_step); Tensor in_slice = in_batch.Slice<T>(g * in_step, (g + 1) * in_step);
im2col(in_slice, col, strides[0], strides[1], paddings[0], paddings[1], im2col(context.device_context(), in_slice, col, strides[0], strides[1],
context.device_context()); paddings[0], paddings[1]);
// gemm // gemm
Tensor out_slice = out_batch.Slice<T>(g * out_step, (g + 1) * out_step); Tensor out_slice = out_batch.Slice<T>(g * out_step, (g + 1) * out_step);
...@@ -185,8 +185,8 @@ class GemmConvGrad2DKernel : public framework::OpKernel { ...@@ -185,8 +185,8 @@ class GemmConvGrad2DKernel : public framework::OpKernel {
// col2im // col2im
Tensor in_grad_slice = Tensor in_grad_slice =
in_grad_batch.Slice<T>(g * in_step, (g + 1) * in_step); in_grad_batch.Slice<T>(g * in_step, (g + 1) * in_step);
col2im(in_grad_slice, col, strides[0], strides[1], paddings[0], col2im(context.device_context(), in_grad_slice, col, strides[0],
paddings[1], context.device_context()); strides[1], paddings[0], paddings[1]);
} }
} }
} }
...@@ -207,8 +207,8 @@ class GemmConvGrad2DKernel : public framework::OpKernel { ...@@ -207,8 +207,8 @@ class GemmConvGrad2DKernel : public framework::OpKernel {
Tensor out_grad_slice = Tensor out_grad_slice =
out_grad_batch.Slice<T>(g * out_step, (g + 1) * out_step); out_grad_batch.Slice<T>(g * out_step, (g + 1) * out_step);
Tensor in_slice = in_batch.Slice<T>(g * in_step, (g + 1) * in_step); Tensor in_slice = in_batch.Slice<T>(g * in_step, (g + 1) * in_step);
im2col(in_slice, col, strides[0], strides[1], paddings[0], im2col(context.device_context(), in_slice, col, strides[0],
paddings[1], context.device_context()); strides[1], paddings[0], paddings[1]);
// gemm // gemm
Tensor filter_grad_slice = Tensor filter_grad_slice =
......
...@@ -27,9 +27,10 @@ template <class T> ...@@ -27,9 +27,10 @@ template <class T>
class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO, class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
platform::CPUPlace, T> { platform::CPUPlace, T> {
public: public:
void operator()(const framework::Tensor& im, framework::Tensor& col, void operator()(const platform::DeviceContext& context,
const framework::Tensor& im, framework::Tensor& col,
int stride_height, int stride_width, int padding_height, int stride_height, int stride_width, int padding_height,
int padding_width, const platform::DeviceContext& context) { int padding_width) {
PADDLE_ENFORCE(im.dims().size() == 3); PADDLE_ENFORCE(im.dims().size() == 3);
PADDLE_ENFORCE(col.dims().size() == 5); PADDLE_ENFORCE(col.dims().size() == 5);
...@@ -79,9 +80,9 @@ template <class T> ...@@ -79,9 +80,9 @@ template <class T>
class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO, class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
platform::CPUPlace, T> { platform::CPUPlace, T> {
public: public:
void operator()(framework::Tensor& im, const framework::Tensor& col, void operator()(const platform::DeviceContext& context, framework::Tensor& im,
int stride_height, int stride_width, int padding_height, const framework::Tensor& col, int stride_height,
int padding_width, const platform::DeviceContext& context) { int stride_width, int padding_height, int padding_width) {
PADDLE_ENFORCE(im.dims().size() == 3); PADDLE_ENFORCE(im.dims().size() == 3);
PADDLE_ENFORCE(col.dims().size() == 5); PADDLE_ENFORCE(col.dims().size() == 5);
int input_channels = im.dims()[0]; int input_channels = im.dims()[0];
...@@ -137,9 +138,10 @@ template <class T> ...@@ -137,9 +138,10 @@ template <class T>
class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF, class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
platform::CPUPlace, T> { platform::CPUPlace, T> {
public: public:
void operator()(const framework::Tensor& im, framework::Tensor& col, void operator()(const platform::DeviceContext& context,
const framework::Tensor& im, framework::Tensor& col,
int stride_height, int stride_width, int padding_height, int stride_height, int stride_width, int padding_height,
int padding_width, const platform::DeviceContext& context) { int padding_width) {
PADDLE_ENFORCE(im.dims().size() == 3); PADDLE_ENFORCE(im.dims().size() == 3);
PADDLE_ENFORCE(col.dims().size() == 5); PADDLE_ENFORCE(col.dims().size() == 5);
int input_channels = im.dims()[0]; int input_channels = im.dims()[0];
...@@ -197,9 +199,9 @@ template <class T> ...@@ -197,9 +199,9 @@ template <class T>
class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF, class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
platform::CPUPlace, T> { platform::CPUPlace, T> {
public: public:
void operator()(framework::Tensor& im, const framework::Tensor& col, void operator()(const platform::DeviceContext& context, framework::Tensor& im,
int stride_height, int stride_width, int padding_height, const framework::Tensor& col, int stride_height,
int padding_width, const platform::DeviceContext& context) { int stride_width, int padding_height, int padding_width) {
PADDLE_ENFORCE(im.dims().size() == 3); PADDLE_ENFORCE(im.dims().size() == 3);
PADDLE_ENFORCE(col.dims().size() == 5); PADDLE_ENFORCE(col.dims().size() == 5);
int input_channels = im.dims()[0]; int input_channels = im.dims()[0];
......
...@@ -64,9 +64,10 @@ template <class T> ...@@ -64,9 +64,10 @@ template <class T>
class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO, class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
platform::GPUPlace, T> { platform::GPUPlace, T> {
public: public:
void operator()(const framework::Tensor& im, framework::Tensor& col, void operator()(const platform::DeviceContext& context,
const framework::Tensor& im, framework::Tensor& col,
int stride_height, int stride_width, int padding_height, int stride_height, int stride_width, int padding_height,
int padding_width, const platform::DeviceContext& context) { int padding_width) {
PADDLE_ENFORCE(im.dims().size() == 3); PADDLE_ENFORCE(im.dims().size() == 3);
PADDLE_ENFORCE(col.dims().size() == 5); PADDLE_ENFORCE(col.dims().size() == 5);
...@@ -149,9 +150,9 @@ template <class T> ...@@ -149,9 +150,9 @@ template <class T>
class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO, class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
platform::GPUPlace, T> { platform::GPUPlace, T> {
public: public:
void operator()(framework::Tensor& im, const framework::Tensor& col, void operator()(const platform::DeviceContext& context, framework::Tensor& im,
int stride_height, int stride_width, int padding_height, const framework::Tensor& col, int stride_height,
int padding_width, const platform::DeviceContext& context) { int stride_width, int padding_height, int padding_width) {
PADDLE_ENFORCE(im.dims().size() == 3); PADDLE_ENFORCE(im.dims().size() == 3);
PADDLE_ENFORCE(col.dims().size() == 5); PADDLE_ENFORCE(col.dims().size() == 5);
...@@ -235,9 +236,10 @@ template <class T> ...@@ -235,9 +236,10 @@ template <class T>
class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF, class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
platform::GPUPlace, T> { platform::GPUPlace, T> {
public: public:
void operator()(const framework::Tensor& im, framework::Tensor& col, void operator()(const platform::DeviceContext& context,
const framework::Tensor& im, framework::Tensor& col,
int stride_height, int stride_width, int padding_height, int stride_height, int stride_width, int padding_height,
int padding_width, const platform::DeviceContext& context) { int padding_width) {
PADDLE_ENFORCE(im.dims().size() == 3); PADDLE_ENFORCE(im.dims().size() == 3);
PADDLE_ENFORCE(col.dims().size() == 5); PADDLE_ENFORCE(col.dims().size() == 5);
int input_channels = im.dims()[0]; int input_channels = im.dims()[0];
...@@ -318,9 +320,9 @@ template <class T> ...@@ -318,9 +320,9 @@ template <class T>
class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF, class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
platform::GPUPlace, T> { platform::GPUPlace, T> {
public: public:
void operator()(framework::Tensor& im, const framework::Tensor& col, void operator()(const platform::DeviceContext& context, framework::Tensor& im,
int stride_height, int stride_width, int padding_height, const framework::Tensor& col, int stride_height,
int padding_width, const platform::DeviceContext& context) { int stride_width, int padding_height, int padding_width) {
PADDLE_ENFORCE(im.dims().size() == 3); PADDLE_ENFORCE(im.dims().size() == 3);
PADDLE_ENFORCE(col.dims().size() == 5); PADDLE_ENFORCE(col.dims().size() == 5);
int input_channels = im.dims()[0]; int input_channels = im.dims()[0];
......
...@@ -72,17 +72,18 @@ enum class ColFormat { kCFO = 0, kOCF = 1 }; ...@@ -72,17 +72,18 @@ enum class ColFormat { kCFO = 0, kOCF = 1 };
template <ColFormat Format, typename Place, typename T> template <ColFormat Format, typename Place, typename T>
class Im2ColFunctor { class Im2ColFunctor {
public: public:
void operator()(const framework::Tensor& im, framework::Tensor& col, void operator()(const platform::DeviceContext& context,
const framework::Tensor& im, framework::Tensor& col,
int stride_height, int stride_width, int padding_height, int stride_height, int stride_width, int padding_height,
int padding_width, const platform::DeviceContext& context); int padding_width);
}; };
template <ColFormat Format, typename Place, typename T> template <ColFormat Format, typename Place, typename T>
class Col2ImFunctor { class Col2ImFunctor {
public: public:
void operator()(framework::Tensor& im, const framework::Tensor& col, void operator()(const platform::DeviceContext& context, framework::Tensor& im,
int stride_height, int stride_width, int padding_height, const framework::Tensor& col, int stride_height,
int padding_width, const platform::DeviceContext& context); int stride_width, int padding_height, int padding_width);
}; };
} // namespace math } // namespace math
......
...@@ -78,8 +78,8 @@ void testIm2col() { ...@@ -78,8 +78,8 @@ void testIm2col() {
PADDLE_THROW("no GPU support"); PADDLE_THROW("no GPU support");
#endif // PADDLE_ONLY_CPU #endif // PADDLE_ONLY_CPU
} }
im2col(input, output_cfo, stride, stride, padding, padding, *context); im2col(*context, input, output_cfo, stride, stride, padding, padding);
im2col_ocf(input, output_ocf, stride, stride, padding, padding, *context); im2col_ocf(*context, input, output_ocf, stride, stride, padding, padding);
float* out_cfo_ptr; float* out_cfo_ptr;
if (paddle::platform::is_cpu_place(*place)) { if (paddle::platform::is_cpu_place(*place)) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册