提交 3f5ae250 编写于 作者: W willzhang4a58

template typename float_point_type

上级 7dfabfd9
......@@ -3,14 +3,14 @@
namespace oneflow {
template<FloatingPointType floating_point_type>
template<typename floating_point_type>
void ConvolutionKernel<DeviceType::kCPU, floating_point_type>::Forward(
const KernelContext&,
std::function<Blob*(const std::string&)> bn_in_op2blob_ptr) const {
TODO();
}
template<FloatingPointType floating_point_type>
template<typename floating_point_type>
void ConvolutionKernel<DeviceType::kCPU, floating_point_type>::Backward(
const KernelContext&,
std::function<Blob*(const std::string&)> bn_in_op2blob_ptr) const {
......
......@@ -3,14 +3,14 @@
namespace oneflow {
template<FloatingPointType floating_point_type>
template<typename floating_point_type>
void ConvolutionKernel<DeviceType::kGPU, floating_point_type>::Forward(
const KernelContext&,
std::function<Blob*(const std::string&)> bn_in_op2blob_ptr) const {
TODO();
}
template<FloatingPointType floating_point_type>
template<typename floating_point_type>
void ConvolutionKernel<DeviceType::kGPU, floating_point_type>::Backward(
const KernelContext&,
std::function<Blob*(const std::string&)> bn_in_op2blob_ptr) const {
......
......@@ -8,11 +8,11 @@
namespace oneflow {
template<DeviceType device_type, FloatingPointType floating_point_type>
template<DeviceType device_type, typename floating_point_type>
class ConvolutionKernel final {
};
template<FloatingPointType floating_point_type>
template<typename floating_point_type>
class ConvolutionKernel<DeviceType::kCPU, floating_point_type> final : public Kernel {
public:
OF_DISALLOW_COPY_AND_MOVE(ConvolutionKernel);
......@@ -23,7 +23,7 @@ class ConvolutionKernel<DeviceType::kCPU, floating_point_type> final : public Ke
void Backward(const KernelContext&, std::function<Blob*(const std::string&)>) const override;
};
template<FloatingPointType floating_point_type>
template<typename floating_point_type>
class ConvolutionKernel<DeviceType::kGPU, floating_point_type> final : public Kernel {
public:
OF_DISALLOW_COPY_AND_MOVE(ConvolutionKernel);
......
......@@ -2,6 +2,12 @@
namespace oneflow {
void Kernel::InitFromOpProto(const OperatorProto& op_proto) {
Operator* op = CreateOp(op_proto.op_conf().op_type_case());
op->InitFromProto(op_proto);
op_.reset(op);
}
void Kernel::InitModelAndModelTmpBlobs(
const KernelContext& ctx,
std::function<Blob*(const std::string&)> Blob4BnInOp) const {
......
......@@ -21,11 +21,7 @@ class Kernel {
OF_DISALLOW_COPY_AND_MOVE(Kernel);
virtual ~Kernel() = default;
void InitFromOpProto(const OperatorProto& op_proto) {
Operator* op = CreateOp(op_proto.op_conf().op_type_case());
op->InitFromProto(op_proto);
op_.reset(op);
}
void InitFromOpProto(const OperatorProto& op_proto);
void InitModelAndModelTmpBlobs(
const KernelContext& ctx,
......@@ -57,12 +53,12 @@ using KernelWardFunc = void (Kernel::*)(
#define INSTANTIATE_CPU_KERNEL_CLASS(classname) \
char gInstantiationGuardCPU##classname; \
template class classname<DeviceType::kCPU, FloatingPointType::kFloat>; \
template class classname<DeviceType::kCPU, FloatingPointType::kDouble>;
template class classname<DeviceType::kCPU, float>; \
template class classname<DeviceType::kCPU, double>;
#define INSTANTIATE_GPU_KERNEL_CLASS(classname) \
char gInstantiationGuardGPU##classname; \
template class classname<DeviceType::kGPU, FloatingPointType::kFloat>; \
template class classname<DeviceType::kGPU, FloatingPointType::kDouble>;
template class classname<DeviceType::kGPU, float>; \
template class classname<DeviceType::kGPU, double>;
} // namespace oneflow
......
......@@ -68,10 +68,10 @@ struct GpuDoubleKernelRegister {
};
#define REGISTER_KERNEL(OpTypeCase, KernelType) \
static CpuFloatKernelRegister<OpTypeCase, KernelType<DeviceType::kCPU, FloatingPointType::kFloat>> g_##KernelType##_cpu_float_regst_var; \
static CpuDoubleKernelRegister<OpTypeCase, KernelType<DeviceType::kCPU, FloatingPointType::kDouble>> g_##KernelType##_cpu_double_regst_var; \
static GpuFloatKernelRegister<OpTypeCase, KernelType<DeviceType::kGPU, FloatingPointType::kFloat>> g_##KernelType##_gpu_float_regst_var; \
static GpuDoubleKernelRegister<OpTypeCase, KernelType<DeviceType::kGPU, FloatingPointType::kDouble>> g_##KernelType##_gpu_double_regst_var;
static CpuFloatKernelRegister<OpTypeCase, KernelType<DeviceType::kCPU, float>> g_##KernelType##_cpu_float_regst_var; \
static CpuDoubleKernelRegister<OpTypeCase, KernelType<DeviceType::kCPU, double>> g_##KernelType##_cpu_double_regst_var; \
static GpuFloatKernelRegister<OpTypeCase, KernelType<DeviceType::kGPU, float>> g_##KernelType##_gpu_float_regst_var; \
static GpuDoubleKernelRegister<OpTypeCase, KernelType<DeviceType::kGPU, double>> g_##KernelType##_gpu_double_regst_var;
} // namespace oneflow
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册