Commit a87953f0 authored by Yi Zhu, committed by Will Zhang

Updt name && Fix bugs (#546)

* fix bugs

* rename op name

* fix kernel names

* remove useless var


Former-commit-id: d6be5427
Parent 6bd244a0
#include "oneflow/core/kernel/average_pooling_kernel.h"
#include "oneflow/core/kernel/average_pooling_2d_kernel.h"
namespace oneflow {
template<DeviceType device_type, typename T>
void AveragePoolingKernel<device_type, T>::ForwardDataContent(
void AveragePooling2DKernel<device_type, T>::ForwardDataContent(
const KernelCtx& ctx,
std::function<Blob*(const std::string&)> BnInOp2Blob) const {
const Blob* in_blob = BnInOp2Blob("in");
Blob* out_blob = BnInOp2Blob("out");
AveragePoolingKernelUtil<device_type, T>::Forward(ctx, in_blob, out_blob,
this->pooling_ctx());
AveragePooling2DKernelUtil<device_type, T>::Forward(ctx, in_blob, out_blob,
this->pooling_2d_ctx());
}
template<DeviceType device_type, typename T>
void AveragePoolingKernel<device_type, T>::BackwardDataContent(
void AveragePooling2DKernel<device_type, T>::BackwardDataContent(
const KernelCtx& ctx,
std::function<Blob*(const std::string&)> BnInOp2Blob) const {
Blob* in_diff_blob = BnInOp2Blob("in_diff");
@@ -21,24 +21,30 @@ void AveragePoolingKernel<device_type, T>::BackwardDataContent(
Memset<device_type>(ctx.device_ctx, in_diff_blob->mut_dptr(), 0,
in_diff_blob->ByteSizeOfDataContentField());
const Blob* out_diff_blob = BnInOp2Blob("out_diff");
AveragePoolingKernelUtil<device_type, T>::Backward(
ctx, out_diff_blob, in_diff_blob, this->pooling_ctx());
AveragePooling2DKernelUtil<device_type, T>::Backward(
ctx, out_diff_blob, in_diff_blob, this->pooling_2d_ctx());
}
template<DeviceType device_type, typename T>
const PoolingKernelConf&
AveragePoolingKernel<device_type, T>::GetPoolingKernelConf() const {
return this->kernel_conf().average_pooling_conf().pooling_conf();
const Pooling2DKernelConf&
AveragePooling2DKernel<device_type, T>::GetPooling2DKernelConf() const {
return this->kernel_conf().average_pooling_2d_conf().pooling_2d_conf();
}
template<DeviceType device_type, typename T>
const PbMessage& AveragePooling2DKernel<device_type, T>::GetPooling2DOpConf()
const {
return this->op_conf().average_pooling_2d_conf();
}
template<typename T>
class AveragePoolingKernelUtil<DeviceType::kCPU, T> final {
class AveragePooling2DKernelUtil<DeviceType::kCPU, T> final {
public:
OF_DISALLOW_COPY_AND_MOVE(AveragePoolingKernelUtil);
AveragePoolingKernelUtil() = delete;
OF_DISALLOW_COPY_AND_MOVE(AveragePooling2DKernelUtil);
AveragePooling2DKernelUtil() = delete;
static void Forward(const KernelCtx& ctx, const Blob* in_blob, Blob* out_blob,
const PoolingCtx& pooling_ctx) {
const Pooling2DCtx& pooling_ctx) {
const T* in_dptr = in_blob->dptr<T>();
T* out_dptr = out_blob->mut_dptr<T>();
@@ -78,7 +84,7 @@ class AveragePoolingKernelUtil<DeviceType::kCPU, T> final {
}
static void Backward(const KernelCtx& ctx, const Blob* out_diff_blob,
Blob* in_diff_blob, const PoolingCtx& pooling_ctx) {
Blob* in_diff_blob, const Pooling2DCtx& pooling_ctx) {
const T* out_diff_dptr = out_diff_blob->dptr<T>();
T* in_diff_dptr = in_diff_blob->mut_dptr<T>();
@@ -120,7 +126,7 @@ class AveragePoolingKernelUtil<DeviceType::kCPU, T> final {
}
};
ADD_DEFAULT_KERNEL_CREATOR(OperatorConf::kAveragePoolingConf,
AveragePoolingKernel, ARITHMETIC_DATA_TYPE_SEQ);
ADD_DEFAULT_KERNEL_CREATOR(OperatorConf::kAveragePooling2DConf,
AveragePooling2DKernel, ARITHMETIC_DATA_TYPE_SEQ);
} // namespace oneflow
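
The loop bodies of the CPU Forward/Backward above are collapsed in this diff; conceptually they average each strided 2D window of an NCHW blob. The standalone sketch below illustrates that computation using the Pooling2DCtx fields introduced in this commit (pool_size_h/w, strides_h/w, padding_top/padding_left). The struct and function names here are hypothetical, and the divisor choice (the clipped window size) may differ from the real kernel.

#include <algorithm>
#include <vector>

// Hypothetical stand-in for oneflow::Pooling2DCtx; fields mirror the ones read
// by BuildPooling2DCtx in this commit.
struct Pooling2DCtxSketch {
  int pool_size_h, pool_size_w;
  int strides_h, strides_w;
  int padding_top, padding_left;
};

// Average pooling over an NCHW tensor stored contiguously in `in`.
// `out` must already be sized to n * c * pooled_h * pooled_w.
void AveragePooling2DForwardSketch(const std::vector<float>& in, std::vector<float>* out,
                                   int n, int c, int h, int w,
                                   int pooled_h, int pooled_w,
                                   const Pooling2DCtxSketch& ctx) {
  for (int ni = 0; ni < n; ++ni) {
    for (int ci = 0; ci < c; ++ci) {
      const float* in_plane = in.data() + (ni * c + ci) * h * w;
      float* out_plane = out->data() + (ni * c + ci) * pooled_h * pooled_w;
      for (int ph = 0; ph < pooled_h; ++ph) {
        for (int pw = 0; pw < pooled_w; ++pw) {
          // Window bounds in input coordinates, clipped to the valid range.
          int hstart = std::max(ph * ctx.strides_h - ctx.padding_top, 0);
          int wstart = std::max(pw * ctx.strides_w - ctx.padding_left, 0);
          int hend = std::min(ph * ctx.strides_h - ctx.padding_top + ctx.pool_size_h, h);
          int wend = std::min(pw * ctx.strides_w - ctx.padding_left + ctx.pool_size_w, w);
          float sum = 0.f;
          for (int hi = hstart; hi < hend; ++hi) {
            for (int wi = wstart; wi < wend; ++wi) { sum += in_plane[hi * w + wi]; }
          }
          const int count = std::max((hend - hstart) * (wend - wstart), 1);
          out_plane[ph * pooled_w + pw] = sum / count;
        }
      }
    }
  }
}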
#include "oneflow/core/device/cuda_util.h"
#include "oneflow/core/kernel/kernel_util.h"
#include "oneflow/core/kernel/average_pooling_kernel.h"
#include "oneflow/core/kernel/average_pooling_2d_kernel.h"
namespace oneflow {
namespace {
template<typename T>
__global__ void AveragePoolForward(const int64_t nthreads, const T* in_dptr,
T* out_dptr, const int64_t channels,
const int64_t height, const int64_t width,
const int64_t pooled_height,
const int64_t pooled_width,
const PoolingCtx ctx) {
__global__ void AveragePooling2DForward(
const int64_t nthreads, const T* in_dptr, T* out_dptr,
const int64_t channels, const int64_t height, const int64_t width,
const int64_t pooled_height, const int64_t pooled_width,
const Pooling2DCtx ctx) {
CUDA_1D_KERNEL_LOOP(index, nthreads) {
const int64_t pw = index % pooled_width;
const int64_t ph = (index / pooled_width) % pooled_height;
@@ -43,10 +42,10 @@ __global__ void AveragePoolForward(const int64_t nthreads, const T* in_dptr,
}
template<typename T>
__global__ void AveragePoolBackward(
__global__ void AveragePooling2DBackward(
const int64_t nthreads, const T* out_diff_dptr, T* in_diff_dptr,
const int64_t channels, const int64_t height, const int64_t width,
const int64_t pooled_height, const int64_t pooled_width, PoolingCtx ctx) {
const int64_t pooled_height, const int64_t pooled_width, Pooling2DCtx ctx) {
CUDA_1D_KERNEL_LOOP(index, nthreads) {
const int64_t w = index % width + ctx.padding_left;
const int64_t h = (index / width) % height + ctx.padding_top;
@@ -86,41 +85,39 @@ __global__ void AveragePoolBackward(
} // namespace
template<typename T>
class AveragePoolingKernelUtil<DeviceType::kGPU, T> final {
class AveragePooling2DKernelUtil<DeviceType::kGPU, T> final {
public:
OF_DISALLOW_COPY_AND_MOVE(AveragePoolingKernelUtil);
AveragePoolingKernelUtil() = delete;
OF_DISALLOW_COPY_AND_MOVE(AveragePooling2DKernelUtil);
AveragePooling2DKernelUtil() = delete;
static void Forward(const KernelCtx& ctx, const Blob* in_blob, Blob* out_blob,
const PoolingCtx& pooling_ctx) {
const Pooling2DCtx& pooling_ctx) {
const int64_t count = out_blob->shape().elem_cnt();
PoolingCtx cuda_ctx = pooling_ctx;
AveragePoolForward<T>
AveragePooling2DForward<T>
<<<BlocksNum4ThreadsNum(count), kCudaThreadsNumPerBlock, 0,
ctx.device_ctx->cuda_stream()>>>(
count, in_blob->dptr<T>(), out_blob->mut_dptr<T>(),
in_blob->shape().At(1), in_blob->shape().At(2),
in_blob->shape().At(3), out_blob->shape().At(2),
out_blob->shape().At(3), cuda_ctx);
out_blob->shape().At(3), pooling_ctx);
}
static void Backward(const KernelCtx& ctx, const Blob* out_diff_blob,
Blob* in_diff_blob, const PoolingCtx& pooling_ctx) {
Blob* in_diff_blob, const Pooling2DCtx& pooling_ctx) {
const int64_t count = in_diff_blob->shape().elem_cnt();
PoolingCtx cuda_ctx = pooling_ctx;
AveragePoolBackward<T>
AveragePooling2DBackward<T>
<<<BlocksNum4ThreadsNum(count), kCudaThreadsNumPerBlock, 0,
ctx.device_ctx->cuda_stream()>>>(
count, out_diff_blob->dptr<T>(), in_diff_blob->mut_dptr<T>(),
in_diff_blob->shape().At(1), in_diff_blob->shape().At(2),
in_diff_blob->shape().At(3), out_diff_blob->shape().At(2),
out_diff_blob->shape().At(3), cuda_ctx);
out_diff_blob->shape().At(3), pooling_ctx);
}
};
#define INSTANTIATE_AVERAGE_POOLING_KERNEL_UTIL(type_cpp, type_proto) \
template class AveragePoolingKernelUtil<DeviceType::kGPU, type_cpp>;
OF_PP_FOR_EACH_TUPLE(INSTANTIATE_AVERAGE_POOLING_KERNEL_UTIL,
#define INSTANTIATE_AVERAGE_POOLING_2D_KERNEL_UTIL(type_cpp, type_proto) \
template class AveragePooling2DKernelUtil<DeviceType::kGPU, type_cpp>;
OF_PP_FOR_EACH_TUPLE(INSTANTIATE_AVERAGE_POOLING_2D_KERNEL_UTIL,
ARITHMETIC_DATA_TYPE_SEQ)
} // namespace oneflow
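
Both CUDA kernels above launch one thread per output element (forward) or input element (backward) and recover NCHW coordinates from the flat index; note also that the renamed kernels now take the Pooling2DCtx by value directly, dropping the intermediate cuda_ctx copy. The helper below spells out the index arithmetic whose first two lines (pw, ph) are visible in the forward hunk; the c and n lines are collapsed in this diff and are reconstructed here as the usual NCHW decomposition, so treat them as an assumption.

#include <cstdint>

// Coordinates of one output element of an NCHW pooled tensor.
struct OutCoord {
  int64_t n, c, ph, pw;
};

// Decode the flat index assigned to a thread (CUDA_1D_KERNEL_LOOP style) into
// (n, c, ph, pw). The pw/ph lines match the kernel above; the c/n lines are the
// standard continuation for channels and batch.
OutCoord DecodeOutputIndex(int64_t index, int64_t channels,
                           int64_t pooled_height, int64_t pooled_width) {
  OutCoord o;
  o.pw = index % pooled_width;
  o.ph = (index / pooled_width) % pooled_height;
  o.c = (index / pooled_width / pooled_height) % channels;
  o.n = index / pooled_width / pooled_height / channels;
  return o;
}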
#ifndef ONEFLOW_CORE_KERNEL_AVERAGE_POOLING_KERNEL_H_
#define ONEFLOW_CORE_KERNEL_AVERAGE_POOLING_KERNEL_H_
#ifndef ONEFLOW_CORE_KERNEL_AVERAGE_POOLING_2D_KERNEL_H_
#define ONEFLOW_CORE_KERNEL_AVERAGE_POOLING_2D_KERNEL_H_
#include "oneflow/core/kernel/pooling_kernel.h"
#include "oneflow/core/kernel/pooling_2d_kernel.h"
namespace oneflow {
template<DeviceType device_type, typename T>
class AveragePoolingKernel final : public PoolingKernel<device_type> {
class AveragePooling2DKernel final : public Pooling2DKernel<device_type> {
public:
OF_DISALLOW_COPY_AND_MOVE(AveragePoolingKernel);
AveragePoolingKernel() = default;
~AveragePoolingKernel() = default;
OF_DISALLOW_COPY_AND_MOVE(AveragePooling2DKernel);
AveragePooling2DKernel() = default;
~AveragePooling2DKernel() = default;
private:
void ForwardDataContent(
@@ -19,20 +19,23 @@ class AveragePoolingKernel final : public PoolingKernel<device_type> {
void BackwardDataContent(
const KernelCtx&,
std::function<Blob*(const std::string&)>) const override;
const PoolingKernelConf& GetPoolingKernelConf() const override;
const Pooling2DKernelConf& GetPooling2DKernelConf() const override;
const PbMessage& GetPooling2DOpConf() const override;
};
template<DeviceType device_type, typename T>
class AveragePoolingKernelUtil {
class AveragePooling2DKernelUtil {
public:
OF_DISALLOW_COPY_AND_MOVE(AveragePoolingKernelUtil);
AveragePoolingKernelUtil() = delete;
OF_DISALLOW_COPY_AND_MOVE(AveragePooling2DKernelUtil);
AveragePooling2DKernelUtil() = delete;
static void Forward(const KernelCtx&, const Blob*, Blob*, const PoolingCtx&);
static void Forward(const KernelCtx&, const Blob*, Blob*,
const Pooling2DCtx&);
static void Backward(const KernelCtx&, const Blob*, Blob*, const PoolingCtx&);
static void Backward(const KernelCtx&, const Blob*, Blob*,
const Pooling2DCtx&);
};
} // namespace oneflow
#endif // ONEFLOW_CORE_KERNEL_AVERAGE_POOLING_KERNEL_H_
#endif // ONEFLOW_CORE_KERNEL_AVERAGE_POOLING_2D_KERNEL_H_
@@ -20,19 +20,19 @@ message SoftmaxLossKernelConf {
required DataType label_type = 2;
}
message PoolingKernelConf {
message Pooling2DKernelConf {
optional int32 padding_top = 1 [default = 0];
optional int32 padding_bottom = 2 [default = 0];
optional int32 padding_left = 3 [default = 0];
optional int32 padding_right = 4 [default = 0];
}
message AveragePoolingKernelConf {
required PoolingKernelConf pooling_conf = 1;
message AveragePooling2DKernelConf {
required Pooling2DKernelConf pooling_2d_conf = 1;
}
message MaxPoolingKernelConf {
required PoolingKernelConf pooling_conf = 1;
message MaxPooling2DKernelConf {
required Pooling2DKernelConf pooling_2d_conf = 1;
}
message ReduceSumKernelConf {
@@ -64,7 +64,7 @@ message KernelConf {
ConcatKernelConf concat_conf = 113;
SoftmaxLossKernelConf softmax_loss_conf = 117;
ReduceSumKernelConf reduce_sum_conf = 120;
AveragePoolingKernelConf average_pooling_conf = 200;
MaxPoolingKernelConf max_pooling_conf = 201;
AveragePooling2DKernelConf average_pooling_2d_conf = 200;
MaxPooling2DKernelConf max_pooling_2d_conf = 201;
}
}
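
The renamed kernel-conf messages above are filled by the ops further down in this commit. As orientation, the sketch below shows the accessor chain the average-pooling op uses; the mutable_*/set_* calls appear verbatim in average_pooling_2d_op.cpp and pooling_2d_op.cpp later in this diff, while the generated-header include path is an assumption and the padding_right line (cut off by the hunk there) is reconstructed by symmetry with padding_bottom.

// Assumed path of the header generated from the kernel conf proto above.
#include "oneflow/core/kernel/kernel.pb.h"

namespace oneflow {

// Fill the padding fields of the renamed AveragePooling2DKernelConf, mirroring
// Pooling2DOp::VirtualGenKernelConf later in this commit. padding_needed_h/w
// are the total SAME paddings for each spatial dimension.
void FillAveragePooling2DPaddingSketch(KernelConf* kernel_conf,
                                       int32_t padding_needed_h,
                                       int32_t padding_needed_w) {
  Pooling2DKernelConf* conf =
      kernel_conf->mutable_average_pooling_2d_conf()->mutable_pooling_2d_conf();
  conf->set_padding_top(padding_needed_h / 2);
  conf->set_padding_bottom(padding_needed_h - padding_needed_h / 2);
  conf->set_padding_left(padding_needed_w / 2);
  conf->set_padding_right(padding_needed_w - padding_needed_w / 2);
}

}  // namespace oneflow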
#include "oneflow/core/kernel/max_pooling_kernel.h"
#include "oneflow/core/kernel/max_pooling_2d_kernel.h"
namespace oneflow {
template<DeviceType device_type, typename T>
void MaxPoolingKernel<device_type, T>::ForwardDataContent(
void MaxPooling2DKernel<device_type, T>::ForwardDataContent(
const KernelCtx& ctx,
std::function<Blob*(const std::string&)> BnInOp2Blob) const {
const Blob* in_blob = BnInOp2Blob("in");
Blob* out_blob = BnInOp2Blob("out");
Blob* idx_blob = BnInOp2Blob("idx");
MaxPoolingKernelUtil<device_type, T>::Forward(ctx, in_blob, out_blob,
idx_blob, this->pooling_ctx());
MaxPooling2DKernelUtil<device_type, T>::Forward(
ctx, in_blob, out_blob, idx_blob, this->pooling_2d_ctx());
}
template<DeviceType device_type, typename T>
void MaxPoolingKernel<device_type, T>::BackwardDataContent(
void MaxPooling2DKernel<device_type, T>::BackwardDataContent(
const KernelCtx& ctx,
std::function<Blob*(const std::string&)> BnInOp2Blob) const {
Blob* in_diff_blob = BnInOp2Blob("in_diff");
@@ -23,24 +23,30 @@ void MaxPoolingKernel<device_type, T>::BackwardDataContent(
in_diff_blob->ByteSizeOfDataContentField());
const Blob* out_diff_blob = BnInOp2Blob("out_diff");
const Blob* idx_blob = BnInOp2Blob("idx");
MaxPoolingKernelUtil<device_type, T>::Backward(
ctx, out_diff_blob, idx_blob, in_diff_blob, this->pooling_ctx());
MaxPooling2DKernelUtil<device_type, T>::Backward(
ctx, out_diff_blob, idx_blob, in_diff_blob, this->pooling_2d_ctx());
}
template<DeviceType device_type, typename T>
const PoolingKernelConf&
MaxPoolingKernel<device_type, T>::GetPoolingKernelConf() const {
return this->kernel_conf().average_pooling_conf().pooling_conf();
const Pooling2DKernelConf&
MaxPooling2DKernel<device_type, T>::GetPooling2DKernelConf() const {
return this->kernel_conf().max_pooling_2d_conf().pooling_2d_conf();
}
template<DeviceType device_type, typename T>
const PbMessage& MaxPooling2DKernel<device_type, T>::GetPooling2DOpConf()
const {
return this->op_conf().max_pooling_2d_conf();
}
template<typename T>
class MaxPoolingKernelUtil<DeviceType::kCPU, T> final {
class MaxPooling2DKernelUtil<DeviceType::kCPU, T> final {
public:
OF_DISALLOW_COPY_AND_MOVE(MaxPoolingKernelUtil);
MaxPoolingKernelUtil() = delete;
OF_DISALLOW_COPY_AND_MOVE(MaxPooling2DKernelUtil);
MaxPooling2DKernelUtil() = delete;
static void Forward(const KernelCtx& ctx, const Blob* in_blob, Blob* out_blob,
Blob* idx_blob, const PoolingCtx& pooling_ctx) {
Blob* idx_blob, const Pooling2DCtx& pooling_ctx) {
const T* in_dptr = in_blob->dptr<T>();
T* out_dptr = out_blob->mut_dptr<T>();
uint32_t* idx_dptr = idx_blob->mut_dptr<uint32_t>();
@@ -83,7 +89,7 @@ class MaxPoolingKernelUtil<DeviceType::kCPU, T> final {
static void Backward(const KernelCtx& ctx, const Blob* out_diff_blob,
const Blob* idx_blob, Blob* in_diff_blob,
const PoolingCtx& pooling_ctx) {
const Pooling2DCtx& pooling_ctx) {
const T* out_diff_dptr = out_diff_blob->dptr<T>();
const uint32_t* idx_dptr = idx_blob->dptr<uint32_t>();
T* in_diff_dptr = in_diff_blob->mut_dptr<T>();
@@ -106,7 +112,7 @@ class MaxPoolingKernelUtil<DeviceType::kCPU, T> final {
}
};
ADD_DEFAULT_KERNEL_CREATOR(OperatorConf::kMaxPoolingConf, MaxPoolingKernel,
ADD_DEFAULT_KERNEL_CREATOR(OperatorConf::kMaxPooling2DConf, MaxPooling2DKernel,
ARITHMETIC_DATA_TYPE_SEQ);
} // namespace oneflow
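
The CPU max-pooling loop bodies are likewise collapsed here. The difference from average pooling is that Forward writes, per output cell, the winning input position into the uint32 "idx" blob so that Backward can route each out_diff value straight back to it. Below is a minimal single-plane sketch of that forward pass, under the same assumed ctx fields as the average-pooling sketch above; OneFlow's actual index encoding and tie-breaking may differ.

#include <algorithm>
#include <cstdint>
#include <limits>

// Max pooling over one h*w input plane, recording for every output cell the
// flat in-plane index of the selected maximum (the role of the "idx" blob).
void MaxPooling2DPlaneForwardSketch(const float* in, float* out, uint32_t* idx,
                                    int h, int w, int pooled_h, int pooled_w,
                                    int pool_size_h, int pool_size_w,
                                    int strides_h, int strides_w,
                                    int padding_top, int padding_left) {
  for (int ph = 0; ph < pooled_h; ++ph) {
    for (int pw = 0; pw < pooled_w; ++pw) {
      const int hstart = std::max(ph * strides_h - padding_top, 0);
      const int wstart = std::max(pw * strides_w - padding_left, 0);
      const int hend = std::min(ph * strides_h - padding_top + pool_size_h, h);
      const int wend = std::min(pw * strides_w - padding_left + pool_size_w, w);
      float best = -std::numeric_limits<float>::infinity();
      uint32_t best_idx = 0;
      for (int hi = hstart; hi < hend; ++hi) {
        for (int wi = wstart; wi < wend; ++wi) {
          if (in[hi * w + wi] > best) {
            best = in[hi * w + wi];
            best_idx = static_cast<uint32_t>(hi * w + wi);
          }
        }
      }
      out[ph * pooled_w + pw] = best;
      idx[ph * pooled_w + pw] = best_idx;
    }
  }
}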
#include "oneflow/core/device/cuda_util.h"
#include "oneflow/core/kernel/kernel_util.h"
#include "oneflow/core/kernel/max_pooling_kernel.h"
#include "oneflow/core/kernel/max_pooling_2d_kernel.h"
namespace oneflow {
namespace {
template<typename T>
__global__ void MaxPoolForward(const int64_t nthreads, const T* in_dptr,
T* out_dptr, uint32_t* mask_dptr,
const int64_t channels, const int64_t height,
const int64_t width, const int64_t pooled_height,
const int64_t pooled_width, PoolingCtx ctx) {
__global__ void MaxPooling2DForward(
const int64_t nthreads, const T* in_dptr, T* out_dptr, uint32_t* mask_dptr,
const int64_t channels, const int64_t height, const int64_t width,
const int64_t pooled_height, const int64_t pooled_width, Pooling2DCtx ctx) {
CUDA_1D_KERNEL_LOOP(index, nthreads) {
const int64_t pw = index % pooled_width;
const int64_t ph = (index / pooled_width) % pooled_height;
@@ -43,12 +42,11 @@ __global__ void MaxPoolForward(const int64_t nthreads, const T* in_dptr,
}
template<typename T>
__global__ void MaxPoolBackward(const int64_t nthreads, const T* out_diff_dptr,
const uint32_t* mask_dptr, T* in_diff_dptr,
const int64_t channels, const int64_t height,
const int64_t width,
const int64_t pooled_height,
const int64_t pooled_width, PoolingCtx ctx) {
__global__ void MaxPooling2DBackward(
const int64_t nthreads, const T* out_diff_dptr, const uint32_t* mask_dptr,
T* in_diff_dptr, const int64_t channels, const int64_t height,
const int64_t width, const int64_t pooled_height,
const int64_t pooled_width, Pooling2DCtx ctx) {
CUDA_1D_KERNEL_LOOP(index, nthreads) {
const int64_t w = index % width;
const int64_t h = (index / width) % height;
@@ -88,40 +86,41 @@ __global__ void MaxPoolBackward(const int64_t nthreads, const T* out_diff_dptr,
} // namespace
template<typename T>
class MaxPoolingKernelUtil<DeviceType::kGPU, T> final {
class MaxPooling2DKernelUtil<DeviceType::kGPU, T> final {
public:
OF_DISALLOW_COPY_AND_MOVE(MaxPoolingKernelUtil);
MaxPoolingKernelUtil() = delete;
OF_DISALLOW_COPY_AND_MOVE(MaxPooling2DKernelUtil);
MaxPooling2DKernelUtil() = delete;
static void Forward(const KernelCtx& ctx, const Blob* in_blob, Blob* out_blob,
Blob* mask_blob, const PoolingCtx& pooling_ctx) {
Blob* mask_blob, const Pooling2DCtx& pooling_ctx) {
const int64_t count = out_blob->shape().elem_cnt();
PoolingCtx cuda_ctx = pooling_ctx;
MaxPoolForward<T><<<BlocksNum4ThreadsNum(count), kCudaThreadsNumPerBlock, 0,
ctx.device_ctx->cuda_stream()>>>(
count, in_blob->dptr<T>(), out_blob->mut_dptr<T>(),
mask_blob->mut_dptr<uint32_t>(), in_blob->shape().At(1),
in_blob->shape().At(2), in_blob->shape().At(3), out_blob->shape().At(2),
out_blob->shape().At(3), cuda_ctx);
MaxPooling2DForward<T>
<<<BlocksNum4ThreadsNum(count), kCudaThreadsNumPerBlock, 0,
ctx.device_ctx->cuda_stream()>>>(
count, in_blob->dptr<T>(), out_blob->mut_dptr<T>(),
mask_blob->mut_dptr<uint32_t>(), in_blob->shape().At(1),
in_blob->shape().At(2), in_blob->shape().At(3),
out_blob->shape().At(2), out_blob->shape().At(3), pooling_ctx);
}
static void Backward(const KernelCtx& ctx, const Blob* out_diff_blob,
const Blob* mask_blob, Blob* in_diff_blob,
const PoolingCtx& pooling_ctx) {
const Pooling2DCtx& pooling_ctx) {
const int64_t count = in_diff_blob->shape().elem_cnt();
PoolingCtx cuda_ctx = pooling_ctx;
MaxPoolBackward<T><<<BlocksNum4ThreadsNum(count), kCudaThreadsNumPerBlock,
0, ctx.device_ctx->cuda_stream()>>>(
count, out_diff_blob->dptr<T>(), mask_blob->dptr<uint32_t>(),
in_diff_blob->mut_dptr<T>(), in_diff_blob->shape().At(1),
in_diff_blob->shape().At(2), in_diff_blob->shape().At(3),
out_diff_blob->shape().At(2), out_diff_blob->shape().At(3), cuda_ctx);
MaxPooling2DBackward<T>
<<<BlocksNum4ThreadsNum(count), kCudaThreadsNumPerBlock, 0,
ctx.device_ctx->cuda_stream()>>>(
count, out_diff_blob->dptr<T>(), mask_blob->dptr<uint32_t>(),
in_diff_blob->mut_dptr<T>(), in_diff_blob->shape().At(1),
in_diff_blob->shape().At(2), in_diff_blob->shape().At(3),
out_diff_blob->shape().At(2), out_diff_blob->shape().At(3),
pooling_ctx);
}
};
#define INSTANTIATE_MAX_POOLING_KERNEL_UTIL(type_cpp, type_proto) \
template class MaxPoolingKernelUtil<DeviceType::kGPU, type_cpp>;
OF_PP_FOR_EACH_TUPLE(INSTANTIATE_MAX_POOLING_KERNEL_UTIL,
#define INSTANTIATE_MAX_POOLING_2D_KERNEL_UTIL(type_cpp, type_proto) \
template class MaxPooling2DKernelUtil<DeviceType::kGPU, type_cpp>;
OF_PP_FOR_EACH_TUPLE(INSTANTIATE_MAX_POOLING_2D_KERNEL_UTIL,
ARITHMETIC_DATA_TYPE_SEQ)
} // namespace oneflow
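
On the GPU, the backward kernel loops over in_diff elements (count = in_diff_blob->shape().elem_cnt()) and consults the mask to decide which out_diff values belong to each input position. The loop below is the equivalent scatter formulation of that gradient routing, shown per plane on the CPU purely for clarity; it sketches the idea, not the kernel's exact gather loop.

#include <cstdint>

// Scatter form of max-pooling backward for one plane: every out_diff element is
// added to the in_diff position that won the forward max, as recorded in the
// mask/idx blob. in_diff is assumed to be zero-initialised beforehand, matching
// the Memset in BackwardDataContent.
void MaxPooling2DPlaneBackwardSketch(const float* out_diff, const uint32_t* idx,
                                     float* in_diff, int pooled_h, int pooled_w) {
  for (int i = 0; i < pooled_h * pooled_w; ++i) { in_diff[idx[i]] += out_diff[i]; }
}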
#ifndef ONEFLOW_CORE_KERNEL_MAX_POOLING_KERNEL_H_
#define ONEFLOW_CORE_KERNEL_MAX_POOLING_KERNEL_H_
#ifndef ONEFLOW_CORE_KERNEL_MAX_POOLING_2D_KERNEL_H_
#define ONEFLOW_CORE_KERNEL_MAX_POOLING_2D_KERNEL_H_
#include "oneflow/core/kernel/pooling_kernel.h"
#include "oneflow/core/kernel/pooling_2d_kernel.h"
namespace oneflow {
template<DeviceType device_type, typename T>
class MaxPoolingKernel final : public PoolingKernel<device_type> {
class MaxPooling2DKernel final : public Pooling2DKernel<device_type> {
public:
OF_DISALLOW_COPY_AND_MOVE(MaxPoolingKernel);
MaxPoolingKernel() = default;
~MaxPoolingKernel() = default;
OF_DISALLOW_COPY_AND_MOVE(MaxPooling2DKernel);
MaxPooling2DKernel() = default;
~MaxPooling2DKernel() = default;
private:
void ForwardDataContent(
@@ -19,22 +19,23 @@ class MaxPoolingKernel final : public PoolingKernel<device_type> {
void BackwardDataContent(
const KernelCtx&,
std::function<Blob*(const std::string&)>) const override;
const PoolingKernelConf& GetPoolingKernelConf() const override;
const Pooling2DKernelConf& GetPooling2DKernelConf() const override;
const PbMessage& GetPooling2DOpConf() const override;
};
template<DeviceType device_type, typename T>
class MaxPoolingKernelUtil {
class MaxPooling2DKernelUtil {
public:
OF_DISALLOW_COPY_AND_MOVE(MaxPoolingKernelUtil);
MaxPoolingKernelUtil() = delete;
OF_DISALLOW_COPY_AND_MOVE(MaxPooling2DKernelUtil);
MaxPooling2DKernelUtil() = delete;
static void Forward(const KernelCtx&, const Blob*, Blob*, Blob*,
const PoolingCtx&);
const Pooling2DCtx&);
static void Backward(const KernelCtx&, const Blob*, const Blob*, Blob*,
const PoolingCtx&);
const Pooling2DCtx&);
};
} // namespace oneflow
#endif // ONEFLOW_CORE_KERNEL_MAX_POOLING_KERNEL_H_
#endif // ONEFLOW_CORE_KERNEL_MAX_POOLING_2D_KERNEL_H_
#include "oneflow/core/kernel/pooling_kernel.h"
#include "oneflow/core/kernel/pooling_2d_kernel.h"
namespace oneflow {
PoolingCtx BuildPoolingCtx(const PbMessage& op_conf,
const PoolingKernelConf& kernel_conf) {
PoolingCtx ctx;
Pooling2DCtx BuildPooling2DCtx(const PbMessage& op_conf,
const Pooling2DKernelConf& kernel_conf) {
Pooling2DCtx ctx;
ctx.pool_size_h = GetInt32FromPbMessage(op_conf, "pool_size_h");
ctx.pool_size_w = GetInt32FromPbMessage(op_conf, "pool_size_w");
ctx.strides_h = GetInt32FromPbMessage(op_conf, "strides_h");
......
#ifndef ONEFLOW_CORE_KERNEL_POOLING_KERNEL_H_
#define ONEFLOW_CORE_KERNEL_POOLING_KERNEL_H_
#ifndef ONEFLOW_CORE_KERNEL_POOLING_2D_KERNEL_H_
#define ONEFLOW_CORE_KERNEL_POOLING_2D_KERNEL_H_
#include "oneflow/core/kernel/kernel.h"
namespace oneflow {
struct PoolingCtx {
struct Pooling2DCtx {
int32_t pool_size_h;
int32_t pool_size_w;
int32_t strides_h;
@@ -17,26 +17,28 @@ struct PoolingCtx {
};
template<DeviceType device_type>
class PoolingKernel : public KernelIf<device_type> {
class Pooling2DKernel : public KernelIf<device_type> {
public:
OF_DISALLOW_COPY_AND_MOVE(PoolingKernel);
PoolingKernel() = default;
virtual ~PoolingKernel() = default;
OF_DISALLOW_COPY_AND_MOVE(Pooling2DKernel);
Pooling2DKernel() = default;
virtual ~Pooling2DKernel() = default;
protected:
void VirtualKernelInit(const ParallelContext*) override {
pooling_ctx_ = BuildPoolingCtx(this->op_conf(), GetPoolingKernelConf());
pooling_2d_ctx_ =
BuildPooling2DCtx(GetPooling2DOpConf(), GetPooling2DKernelConf());
}
const PoolingCtx& pooling_ctx() const { return pooling_ctx_; }
virtual const PoolingKernelConf& GetPoolingKernelConf() const = 0;
const Pooling2DCtx& pooling_2d_ctx() const { return pooling_2d_ctx_; }
virtual const Pooling2DKernelConf& GetPooling2DKernelConf() const = 0;
virtual const PbMessage& GetPooling2DOpConf() const = 0;
private:
PoolingCtx pooling_ctx_;
Pooling2DCtx pooling_2d_ctx_;
};
PoolingCtx BuildPoolingCtx(const PbMessage& op_conf,
const PoolingKernelConf& kernel_conf);
Pooling2DCtx BuildPooling2DCtx(const PbMessage& op_conf,
const Pooling2DKernelConf& kernel_conf);
} // namespace oneflow
#endif // ONEFLOW_CORE_KERNEL_POOLING_KERNEL_H_
#endif // ONEFLOW_CORE_KERNEL_POOLING_2D_KERNEL_H_
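
The added GetPooling2DOpConf() hook appears to be the substantive fix in this base class: the old VirtualKernelInit handed the whole op_conf() (the enclosing OperatorConf) to BuildPoolingCtx, while BuildPooling2DCtx looks fields up by name (GetInt32FromPbMessage(op_conf, "pool_size_h") and friends), which only resolves on the pooling-specific sub-message that the derived kernels now return. The sketch below shows what such a by-name lookup amounts to with standard protobuf reflection; it is an illustration, not OneFlow's actual helper.

#include <google/protobuf/message.h>

#include <cstdint>
#include <string>

// Reflection-based by-name int32 lookup, in the spirit of
// GetInt32FromPbMessage(op_conf, "pool_size_h"): it succeeds only if `msg`
// actually declares the field, i.e. if the caller passes the pooling-specific
// op conf rather than the enclosing OperatorConf. Sketch only.
int32_t GetInt32ByNameSketch(const google::protobuf::Message& msg,
                             const std::string& name) {
  const google::protobuf::FieldDescriptor* field =
      msg.GetDescriptor()->FindFieldByName(name);
  if (field == nullptr) { return 0; }  // real code would fail loudly here
  return msg.GetReflection()->GetInt32(msg, field);
}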
#include "oneflow/core/operator/average_pooling_2d_op.h"
namespace oneflow {
const PbMessage& AveragePooling2DOp::GetSpecialConf() const {
return op_conf().average_pooling_2d_conf();
}
Pooling2DKernelConf* AveragePooling2DOp::GetMutPooling2DKernelConf(
KernelConf* kernel_conf) const {
return kernel_conf->mutable_average_pooling_2d_conf()
->mutable_pooling_2d_conf();
}
REGISTER_OP(OperatorConf::kAveragePooling2DConf, AveragePooling2DOp);
} // namespace oneflow
#ifndef ONEFLOW_CORE_OPERATOR_AVERAGE_POOLING_2D_OP_H_
#define ONEFLOW_CORE_OPERATOR_AVERAGE_POOLING_2D_OP_H_
#include "oneflow/core/operator/pooling_2d_op.h"
namespace oneflow {
class AveragePooling2DOp final : public Pooling2DOp {
public:
OF_DISALLOW_COPY_AND_MOVE(AveragePooling2DOp);
AveragePooling2DOp() = default;
~AveragePooling2DOp() = default;
const PbMessage& GetSpecialConf() const override;
private:
Pooling2DKernelConf* GetMutPooling2DKernelConf(KernelConf*) const override;
};
} // namespace oneflow
#endif // ONEFLOW_CORE_OPERATOR_AVERAGE_POOLING_2D_OP_H_
#include "oneflow/core/operator/average_pooling_op.h"
namespace oneflow {
const PbMessage& AveragePoolingOp::GetSpecialConf() const {
return op_conf().average_pooling_conf();
}
PoolingKernelConf* AveragePoolingOp::GetMutPoolingKernelConf(
KernelConf* kernel_conf) const {
return kernel_conf->mutable_average_pooling_conf()->mutable_pooling_conf();
}
REGISTER_OP(OperatorConf::kAveragePoolingConf, AveragePoolingOp);
} // namespace oneflow
#ifndef ONEFLOW_CORE_OPERATOR_AVERAGE_POOLING_OP_H_
#define ONEFLOW_CORE_OPERATOR_AVERAGE_POOLING_OP_H_
#include "oneflow/core/operator/pooling_op.h"
namespace oneflow {
class AveragePoolingOp final : public PoolingOp {
public:
OF_DISALLOW_COPY_AND_MOVE(AveragePoolingOp);
AveragePoolingOp() = default;
~AveragePoolingOp() = default;
const PbMessage& GetSpecialConf() const override;
private:
PoolingKernelConf* GetMutPoolingKernelConf(KernelConf*) const override;
};
} // namespace oneflow
#endif // ONEFLOW_CORE_OPERATOR_AVERAGE_POOLING_OP_H_
#include "oneflow/core/operator/max_pooling_op.h"
#include "oneflow/core/operator/max_pooling_2d_op.h"
namespace oneflow {
const PbMessage& MaxPoolingOp::GetSpecialConf() const {
return op_conf().max_pooling_conf();
const PbMessage& MaxPooling2DOp::GetSpecialConf() const {
return op_conf().max_pooling_2d_conf();
}
void MaxPoolingOp::VirtualEnrollDataTmpBn() { EnrollDataTmpBn("idx"); }
void MaxPooling2DOp::VirtualEnrollDataTmpBn() { EnrollDataTmpBn("idx"); }
void MaxPoolingOp::VirtualInferDataTmpBlobDesc(
void MaxPooling2DOp::VirtualInferDataTmpBlobDesc(
std::function<BlobDesc*(const std::string)> GetBlobDesc4BnInOp) const {
BlobDesc* out_blob_desc = GetBlobDesc4BnInOp("out");
BlobDesc* idx_blob_desc = GetBlobDesc4BnInOp("idx");
@@ -16,11 +16,11 @@ void MaxPoolingOp::VirtualInferDataTmpBlobDesc(
idx_blob_desc->set_data_type(DataType::kUInt32);
}
PoolingKernelConf* MaxPoolingOp::GetMutPoolingKernelConf(
Pooling2DKernelConf* MaxPooling2DOp::GetMutPooling2DKernelConf(
KernelConf* kernel_conf) const {
return kernel_conf->mutable_max_pooling_conf()->mutable_pooling_conf();
return kernel_conf->mutable_max_pooling_2d_conf()->mutable_pooling_2d_conf();
}
REGISTER_OP(OperatorConf::kMaxPoolingConf, MaxPoolingOp);
REGISTER_OP(OperatorConf::kMaxPooling2DConf, MaxPooling2DOp);
} // namespace oneflow
#ifndef ONEFLOW_CORE_OPERATOR_MAX_POOLING_OP_H_
#define ONEFLOW_CORE_OPERATOR_MAX_POOLING_OP_H_
#ifndef ONEFLOW_CORE_OPERATOR_MAX_POOLING_2D_OP_H_
#define ONEFLOW_CORE_OPERATOR_MAX_POOLING_2D_OP_H_
#include "oneflow/core/operator/pooling_op.h"
#include "oneflow/core/operator/pooling_2d_op.h"
namespace oneflow {
class MaxPoolingOp final : public PoolingOp {
class MaxPooling2DOp final : public Pooling2DOp {
public:
OF_DISALLOW_COPY_AND_MOVE(MaxPoolingOp);
MaxPoolingOp() = default;
~MaxPoolingOp() = default;
OF_DISALLOW_COPY_AND_MOVE(MaxPooling2DOp);
MaxPooling2DOp() = default;
~MaxPooling2DOp() = default;
const PbMessage& GetSpecialConf() const override;
@@ -17,9 +17,9 @@ class MaxPoolingOp final : public PoolingOp {
void VirtualEnrollDataTmpBn() override;
void VirtualInferDataTmpBlobDesc(std::function<BlobDesc*(const std::string)>
GetBlobDesc4BnInOp) const override;
PoolingKernelConf* GetMutPoolingKernelConf(KernelConf*) const override;
Pooling2DKernelConf* GetMutPooling2DKernelConf(KernelConf*) const override;
};
} // namespace oneflow
#endif // ONEFLOW_CORE_OPERATOR_MAX_POOLING_OP_H_
#endif // ONEFLOW_CORE_OPERATOR_MAX_POOLING_2D_OP_H_
@@ -83,7 +83,7 @@ message BasicDataLoaderOpConf {
required ShapeProto shape = 5;
}
message AveragePoolingOpConf {
message AveragePooling2DOpConf {
required string in = 1;
required string out = 2;
@@ -94,7 +94,7 @@ message AveragePoolingOpConf {
optional int32 strides_w = 7 [default = 1];
}
message MaxPoolingOpConf {
message MaxPooling2DOpConf {
required string in = 1;
required string out = 2;
@@ -277,8 +277,8 @@ message OperatorConf {
LossPrintOpConf loss_print_conf = 119;
ReduceSumOpConf reduce_sum_conf = 120;
RecurrentOpConf recurrent_conf = 121;
AveragePoolingOpConf average_pooling_conf = 200;
MaxPoolingOpConf max_pooling_conf = 201;
AveragePooling2DOpConf average_pooling_2d_conf = 200;
MaxPooling2DOpConf max_pooling_2d_conf = 201;
}
}
......
#include "oneflow/core/operator/pooling_op.h"
#include "oneflow/core/operator/pooling_2d_op.h"
namespace oneflow {
void PoolingOp::InitFromOpConf() {
void Pooling2DOp::InitFromOpConf() {
std::string padding_mthd = GetStringFromSpecialConf("padding");
std::transform(padding_mthd.begin(), padding_mthd.end(), padding_mthd.begin(),
::tolower);
if (padding_mthd != "same" || padding_mthd != "valid") {
if (padding_mthd != "same" && padding_mthd != "valid") {
LOG(FATAL) << "Invalid padding method in " << op_name();
}
SetStringInSpecialConf("padding", padding_mthd);
@@ -16,7 +16,7 @@ void PoolingOp::InitFromOpConf() {
VirtualEnrollDataTmpBn();
}
void PoolingOp::InferBlobDescs(
void Pooling2DOp::InferBlobDescs(
std::function<BlobDesc*(const std::string)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx) const {
// in
@@ -36,7 +36,7 @@ void PoolingOp::InferBlobDescs(
VirtualInferDataTmpBlobDesc(GetBlobDesc4BnInOp);
}
void PoolingOp::VirtualGenKernelConf(
void Pooling2DOp::VirtualGenKernelConf(
std::function<const BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx, KernelConf* kernel_conf) const {
std::string padding_mthd = GetStringFromSpecialConf("padding");
@@ -50,7 +50,7 @@ void PoolingOp::VirtualGenKernelConf(
const int32_t padding_needed_w =
(std::get<1>(out_size) - 1) * GetInt32FromSpecialConf("strides_w")
+ GetInt32FromSpecialConf("pool_size_w") - in_blob_desc->shape().At(3);
PoolingKernelConf* pooling_conf = GetMutPoolingKernelConf(kernel_conf);
Pooling2DKernelConf* pooling_conf = GetMutPooling2DKernelConf(kernel_conf);
pooling_conf->set_padding_top(padding_needed_h / 2);
pooling_conf->set_padding_bottom(padding_needed_h - padding_needed_h / 2);
pooling_conf->set_padding_left(padding_needed_w / 2);
@@ -58,8 +58,8 @@
}
}
std::tuple<int32_t, int32_t> PoolingOp::CalcOutSize(int32_t in_h,
int32_t in_w) const {
std::tuple<int32_t, int32_t> Pooling2DOp::CalcOutSize(int32_t in_h,
int32_t in_w) const {
int32_t pool_size_h = GetInt32FromSpecialConf("pool_size_h");
int32_t pool_size_w = GetInt32FromSpecialConf("pool_size_w");
int32_t strides_h = GetInt32FromSpecialConf("strides_h");
......
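
VirtualGenKernelConf above converts the SAME-padding requirement into explicit per-side paddings: the total padding for a dimension is (out - 1) * stride + pool_size - in, split so the larger half goes to the bottom/right (the padding_right line is cut off by the hunk and assumed symmetric to padding_bottom). A compact restatement:

#include <cstdint>
#include <utility>

// Split the total SAME padding between the two sides of one spatial dimension,
// mirroring padding_needed_h / padding_needed_w in VirtualGenKernelConf above:
// the front (top/left) side gets padding_needed / 2, the back (bottom/right)
// side gets the remainder.
std::pair<int32_t, int32_t> SplitSamePaddingSketch(int32_t in, int32_t out,
                                                   int32_t pool_size, int32_t stride) {
  const int32_t padding_needed = (out - 1) * stride + pool_size - in;
  return {padding_needed / 2, padding_needed - padding_needed / 2};
}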
#ifndef ONEFLOW_CORE_OPERATOR_POOLING_OP_H_
#define ONEFLOW_CORE_OPERATOR_POOLING_OP_H_
#ifndef ONEFLOW_CORE_OPERATOR_POOLING_2D_OP_H_
#define ONEFLOW_CORE_OPERATOR_POOLING_2D_OP_H_
#include "oneflow/core/operator/operator.h"
namespace oneflow {
class PoolingOp : public Operator {
class Pooling2DOp : public Operator {
public:
OF_DISALLOW_COPY_AND_MOVE(PoolingOp);
PoolingOp() = default;
virtual ~PoolingOp() = default;
OF_DISALLOW_COPY_AND_MOVE(Pooling2DOp);
Pooling2DOp() = default;
virtual ~Pooling2DOp() = default;
void InitFromOpConf() override;
@@ -28,7 +28,7 @@ class PoolingOp : public Operator {
std::function<const BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx,
KernelConf* kernel_conf) const override;
virtual PoolingKernelConf* GetMutPoolingKernelConf(KernelConf*) const = 0;
virtual Pooling2DKernelConf* GetMutPooling2DKernelConf(KernelConf*) const = 0;
private:
std::tuple<int, int> CalcOutSize(int32_t in_h, int32_t in_w) const;
@@ -36,4 +36,4 @@ class PoolingOp : public Operator {
} // namespace oneflow
#endif // ONEFLOW_CORE_OPERATOR_POOLING_OP_H_
#endif // ONEFLOW_CORE_OPERATOR_POOLING_2D_OP_H_
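
CalcOutSize's body is collapsed in this diff, so the exact formula is not visible here. The usual convention behind "same"/"valid" padding strings, offered below only as an assumption about what the collapsed code computes, is ceil(in / stride) for SAME and ceil((in - pool_size + 1) / stride) for VALID.

#include <cstdint>
#include <string>

// Assumed output-size rule for the "same" / "valid" padding strings accepted by
// Pooling2DOp::InitFromOpConf. This is the common TensorFlow-style convention,
// NOT verified against the collapsed CalcOutSize body.
int32_t OutSizeSketch(int32_t in, int32_t pool_size, int32_t stride,
                      const std::string& padding) {
  if (padding == "same") {
    return (in + stride - 1) / stride;            // ceil(in / stride)
  } else {                                        // "valid"
    return (in - pool_size + stride) / stride;    // ceil((in - pool_size + 1) / stride)
  }
}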