Commit d152ec5f authored by zhaojiaying01

Add a fusion_conv_relu op for CPU and GPU_CL: a fusion matcher folds a conv2d node followed by a relu node into a single operator, and the ReLU runs inside the fused kernel (an in-place epilogue on CPU, and CL kernels built with -DRELU on GPU_CL) instead of as a standalone relu op.

Parent b025553b
......@@ -31,6 +31,7 @@ const char *G_OP_TYPE_FUSION_CONV_ADD_ADD_PRELU = "fusion_conv_add_add_prelu";
const char *G_OP_TYPE_FUSION_CONV_ADD_BN_RELU = "fusion_conv_add_bn_relu";
const char *G_OP_TYPE_FUSION_CONV_BN_ADD_RELU = "fusion_conv_bn_add_relu";
const char *G_OP_TYPE_FUSION_DWCONV_BN_RELU = "fusion_dwconv_bn_relu";
const char *G_OP_TYPE_FUSION_CONV_RELU = "fusion_conv_relu";
const char *G_OP_TYPE_FUSION_CONV_BN_RELU = "fusion_conv_bn_relu";
const char *G_OP_TYPE_FC = "fusion_fc";
const char *G_OP_TYPE_FUSION_CONV_ADD = "fusion_conv_add";
......@@ -125,6 +126,7 @@ std::unordered_map<
op_input_output_key = {
{G_OP_TYPE_CONV, {{"Input"}, {"Output"}}},
{G_OP_TYPE_FUSION_DWCONV_BN_RELU, {{"Input"}, {"Out"}}},
{G_OP_TYPE_FUSION_CONV_RELU, {{"Input"}, {"Out"}}},
{G_OP_TYPE_FUSION_CONV_BN_RELU, {{"Input"}, {"Out"}}},
{G_OP_TYPE_PRELU, {{"X", "Alpha"}, {"Out"}}},
{G_OP_TYPE_FUSION_CONV_ADD, {{"Input"}, {"Out"}}},
......
......@@ -151,6 +151,7 @@ extern const char *G_OP_TYPE_FUSION_CONV_ADD_BN_RELU;
extern const char *G_OP_TYPE_FUSION_CONV_BN_ADD_RELU;
extern const char *G_OP_TYPE_FUSION_DWCONV_BN_RELU;
extern const char *G_OP_TYPE_FUSION_CONV_BN_RELU;
extern const char *G_OP_TYPE_FUSION_CONV_RELU;
extern const char *G_OP_TYPE_GRU;
extern const char *G_OP_TYPE_GRU_UNIT;
......
......@@ -168,6 +168,10 @@ LOAD_FUSION_MATCHER(fusion_conv_bn_add_relu);
LOAD_OP3(fusion_conv_bn_relu, CPU, GPU_CL, FPGA);
LOAD_FUSION_MATCHER(fusion_conv_bn_relu);
#endif
#ifdef FUSION_CONVRELU_OP
LOAD_OP2(fusion_conv_relu, CPU, GPU_CL);
LOAD_FUSION_MATCHER(fusion_conv_relu);
#endif
#ifdef GRU_OP
LOAD_OP1(gru, CPU);
#endif
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef FUSION_CONVRELU_OP
#include "operators/fusion_conv_relu_op.h"
#include "operators/kernel/central-arm-func/conv_arm_func.h"
namespace paddle_mobile {
namespace operators {
template <typename Dtype, typename T>
void FusionConvReluOp<Dtype, T>::InferShape() const {
auto in_dims = this->param_.Input()->dims();
auto filter_dims = this->param_.Filter()->dims();
const std::vector<int> &strides = this->param_.Strides();
std::vector<int> paddings = this->param_.Paddings();
int groups = this->param_.Groups();
std::vector<int> dilations = this->param_.Dilations();
PADDLE_MOBILE_ENFORCE((in_dims.size() == filter_dims.size() &&
dilations.size() == paddings.size() &&
paddings.size() == strides.size()),
"ConvParam is not suitable");
std::vector<int64_t> output_shape({in_dims[0], filter_dims[0]});
for (size_t i = 0; i < strides.size(); ++i) {
output_shape.push_back(ConvOutputSize(in_dims[i + 2], filter_dims[i + 2],
dilations[i], paddings[i],
strides[i]));
}
framework::DDim ddim = framework::make_ddim(output_shape);
this->param_.Output()->Resize(ddim);
}
} // namespace operators
} // namespace paddle_mobile
namespace ops = paddle_mobile::operators;
REGISTER_FUSION_MATCHER(fusion_conv_relu, ops::FusionConvReluMatcher);
#ifdef PADDLE_MOBILE_CPU
REGISTER_OPERATOR_CPU(fusion_conv_relu, ops::FusionConvReluOp);
#endif
#ifdef PADDLE_MOBILE_CL
REGISTER_OPERATOR_CL(fusion_conv_relu, ops::FusionConvReluOp);
#endif
#if defined(PADDLE_MOBILE_FPGA) || defined(PADDLE_MOBILE_FPGA_KD)
REGISTER_OPERATOR_FPGA(fusion_conv_relu, ops::FusionConvReluOp);
#endif
#endif
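For reference, the ConvOutputSize helper that InferShape calls above computes each spatial extent with the standard dilated-convolution size formula. A minimal sketch of its logic (the real definition lives with the central ARM conv functions):

inline int ConvOutputSize(int input_size, int filter_size, int dilation,
                          int padding, int stride) {
  // Effective kernel extent once dilation is applied.
  const int dkernel = dilation * (filter_size - 1) + 1;
  return (input_size + 2 * padding - dkernel) / stride + 1;
}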
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef FUSION_CONVRELU_OP
#pragma once
#include <string>
#include <vector>
#include "framework/operator.h"
#include "framework/program/program-optimize/fusion_op_register.h"
#include "operators/kernel/conv_relu_kernel.h"
#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
class FusionConvReluMatcher : public framework::FusionOpMatcher {
public:
FusionConvReluMatcher() {
node_ = framework::Node(G_OP_TYPE_CONV);
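// The overloaded operator> appends a successor to the pattern: match a conv2d node followed by a relu node.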
node_ > std::make_shared<framework::Node>(G_OP_TYPE_RELU);
}
void FolderNodes(
framework::Node *node,
std::vector<std::shared_ptr<framework::Node>> *removed_nodes) {
node->Folder(node_.Depth(), Type(), {}, removed_nodes);
}
std::string Type() { return G_OP_TYPE_FUSION_CONV_RELU; }
};
template <typename DeviceType, typename T>
class FusionConvReluOp : public framework::OperatorWithKernel<
DeviceType, FusionConvReluParam<DeviceType>,
operators::ConvReluKernel<DeviceType, T>> {
public:
FusionConvReluOp(const std::string &type, const VariableNameMap &inputs,
const VariableNameMap &outputs,
const framework::AttributeMap &attrs,
framework::Scope *scope)
: framework::OperatorWithKernel<DeviceType,
FusionConvReluParam<DeviceType>,
operators::ConvReluKernel<DeviceType, T>>(
type, inputs, outputs, attrs, scope) {}
void InferShape() const override;
protected:
};
} // namespace operators
} // namespace paddle_mobile
#endif
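The matcher idiom scales to longer chains. For comparison, the existing conv+add+relu matcher strings two successors together the same way (an abridged sketch of the constructor in fusion_conv_add_relu_op.h):

FusionConvAddReluOpMatcher() {
  node_ = framework::Node(G_OP_TYPE_CONV);
  node_ > std::make_shared<framework::Node>(G_OP_TYPE_ELEMENTWISE_ADD) >
      std::make_shared<framework::Node>(G_OP_TYPE_RELU);
}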
......@@ -14,6 +14,7 @@ limitations under the License. */
#include "operators/kernel/activation_kernel.h"
#include "common/types.h"
#include "operators/kernel/central-arm-func/activation_arm_func.h"
#include "operators/math/activation.h"
#if defined(__ARM_NEON__) || defined(__ARM_NEON)
#include <arm_neon.h>
......@@ -22,86 +23,6 @@ limitations under the License. */
namespace paddle_mobile {
namespace operators {
template <typename Dtype, ActivationType Act>
struct ActivationCompute {
void operator()(const Tensor *input, Tensor *output) {}
void operator()(const Tensor *input, Tensor *output, float alpha) {}
};
template <ActivationType Act>
struct ActivationCompute<float, Act> {
void operator()(const Tensor *input, Tensor *output) {
const float *x = input->data<float>();
float *y = output->mutable_data<float>();
size_t remain = input->numel();
#if defined(__ARM_NEON__) || defined(__ARM_NEON)
size_t loop = remain >> 4;
remain = remain & 0xF;
#pragma omp parallel for
for (size_t i = 0; i < loop; ++i) {
const float *local_x = x + (i << 4);
float *local_y = y + (i << 4);
float32x4_t r0 = vld1q_f32(local_x);
float32x4_t r1 = vld1q_f32(local_x + 4);
float32x4_t r2 = vld1q_f32(local_x + 8);
float32x4_t r3 = vld1q_f32(local_x + 12);
r0 = math::vActiveq_f32<Act>(r0);
r1 = math::vActiveq_f32<Act>(r1);
r2 = math::vActiveq_f32<Act>(r2);
r3 = math::vActiveq_f32<Act>(r3);
vst1q_f32(local_y, r0);
vst1q_f32(local_y + 4, r1);
vst1q_f32(local_y + 8, r2);
vst1q_f32(local_y + 12, r3);
}
x += (loop << 4);
y += (loop << 4);
#endif
for (size_t i = 0; i < remain; ++i) {
y[i] = math::Active<Act>(x[i]);
}
}
void operator()(const Tensor *input, Tensor *output, float falpha) {
const float *x = input->data<float>();
float *y = output->mutable_data<float>();
size_t remain = input->numel();
float alphas[4] = {falpha, falpha, falpha, falpha};
#if defined(__ARM_NEON__) || defined(__ARM_NEON)
size_t loop = remain >> 4;
remain = remain & 0xF;
#pragma omp parallel for
for (size_t i = 0; i < loop; ++i) {
const float *local_x = x + (i << 4);
float *local_y = y + (i << 4);
float32x4_t r0 = vld1q_f32(local_x);
float32x4_t r1 = vld1q_f32(local_x + 4);
float32x4_t r2 = vld1q_f32(local_x + 8);
float32x4_t r3 = vld1q_f32(local_x + 12);
float32x4_t a_r0 = vld1q_f32(alphas);
float32x4_t a_r1 = vld1q_f32(alphas);
float32x4_t a_r2 = vld1q_f32(alphas);
float32x4_t a_r3 = vld1q_f32(alphas);
r0 = math::vActiveq_f32<Act>(r0, a_r0);
r1 = math::vActiveq_f32<Act>(r1, a_r1);
r2 = math::vActiveq_f32<Act>(r2, a_r2);
r3 = math::vActiveq_f32<Act>(r3, a_r3);
vst1q_f32(local_y, r0);
vst1q_f32(local_y + 4, r1);
vst1q_f32(local_y + 8, r2);
vst1q_f32(local_y + 12, r3);
}
x += (loop << 4);
y += (loop << 4);
#endif
for (size_t i = 0; i < remain; ++i) {
y[i] = math::Active<Act>(x[i], falpha);
}
}
};
#ifdef RELU_OP
template <>
bool ReluKernel<CPU, float>::Init(ReluParam<CPU> *param) {
......
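(The ActivationCompute functor removed above is not deleted outright: it moves into the new header operators/kernel/central-arm-func/activation_arm_func.h below, so the new conv_relu kernel can reuse it.)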
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef FUSION_CONVRELU_OP
#include "operators/kernel/conv_relu_kernel.h"
#include "operators/kernel/arm/convolution/conv_common.h"
#include "operators/kernel/central-arm-func/activation_arm_func.h"
#include "operators/kernel/central-arm-func/conv_arm_func.h"
namespace paddle_mobile {
namespace operators {
template <>
bool ConvReluKernel<CPU, float>::Init(FusionConvReluParam<CPU> *param) {
InitBaseConvKernel(param);
return true;
}
template <>
void ConvReluKernel<CPU, float>::Compute(
const FusionConvReluParam<CPU> &param) {
switch (param.ExecMode()) {
case ConvParam<CPU>::EXEC_DEPTHWISE3x3S1_FLOAT:
case ConvParam<CPU>::EXEC_DEPTHWISE3x3S2_FLOAT:
DepthwiseConv3x3<float, float>(param);
break;
case ConvParam<CPU>::EXEC_DEPTHWISE5x5_FLOAT:
DepthwiseConv5x5<float, float>(param);
break;
case ConvParam<CPU>::EXEC_WINOGRAD3X3_FLOAT:
WinogradConv3x3<8, 3>(param);
break;
case ConvParam<CPU>::EXEC_GEMM_FLOAT:
GemmConv<float, float>(param);
break;
case ConvParam<CPU>::EXEC_SLIDINGWINDOW3x3S1_FLOAT:
case ConvParam<CPU>::EXEC_SLIDINGWINDOW3x3S2_FLOAT:
SlidingwindowConv3x3<float, float>(param);
break;
default:
PADDLE_MOBILE_THROW_EXCEPTION("Invalid convolution execute mode %d",
param.ExecMode());
}
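// Fused epilogue: run ReLU in place on the convolution output; this is the relu node the matcher folded away.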
ActivationCompute<float, RELU>()(param.Output(), param.Output());
}
template class ConvReluKernel<CPU, float>;
} // namespace operators
} // namespace paddle_mobile
#endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "operators/math/activation.h"
#include "operators/op_param.h"
#if defined(__ARM_NEON__) || defined(__ARM_NEON)
#include <arm_neon.h>
#endif // __ARM_NEON__
namespace paddle_mobile {
namespace operators {
template <typename Dtype, ActivationType Act>
struct ActivationCompute {
void operator()(const Tensor *input, Tensor *output) {}
void operator()(const Tensor *input, Tensor *output, float alpha) {}
};
template <ActivationType Act>
struct ActivationCompute<float, Act> {
void operator()(const Tensor *input, Tensor *output) {
const float *x = input->data<float>();
float *y = output->mutable_data<float>();
size_t remain = input->numel();
#if defined(__ARM_NEON__) || defined(__ARM_NEON)
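// Vector path: 16 floats per iteration, held in four float32x4_t registers; the scalar loop below handles the remaining numel % 16 elements.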
size_t loop = remain >> 4;
remain = remain & 0xF;
#pragma omp parallel for
for (size_t i = 0; i < loop; ++i) {
const float *local_x = x + (i << 4);
float *local_y = y + (i << 4);
float32x4_t r0 = vld1q_f32(local_x);
float32x4_t r1 = vld1q_f32(local_x + 4);
float32x4_t r2 = vld1q_f32(local_x + 8);
float32x4_t r3 = vld1q_f32(local_x + 12);
r0 = math::vActiveq_f32<Act>(r0);
r1 = math::vActiveq_f32<Act>(r1);
r2 = math::vActiveq_f32<Act>(r2);
r3 = math::vActiveq_f32<Act>(r3);
vst1q_f32(local_y, r0);
vst1q_f32(local_y + 4, r1);
vst1q_f32(local_y + 8, r2);
vst1q_f32(local_y + 12, r3);
}
x += (loop << 4);
y += (loop << 4);
#endif
for (size_t i = 0; i < remain; ++i) {
y[i] = math::Active<Act>(x[i]);
}
}
void operator()(const Tensor *input, Tensor *output, float falpha) {
const float *x = input->data<float>();
float *y = output->mutable_data<float>();
size_t remain = input->numel();
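// Broadcast the scalar alpha to all four lanes so parameterized activations (e.g. leaky relu) can consume it lane-wise.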
float alphas[4] = {falpha, falpha, falpha, falpha};
#if defined(__ARM_NEON__) || defined(__ARM_NEON)
size_t loop = remain >> 4;
remain = remain & 0xF;
#pragma omp parallel for
for (size_t i = 0; i < loop; ++i) {
const float *local_x = x + (i << 4);
float *local_y = y + (i << 4);
float32x4_t r0 = vld1q_f32(local_x);
float32x4_t r1 = vld1q_f32(local_x + 4);
float32x4_t r2 = vld1q_f32(local_x + 8);
float32x4_t r3 = vld1q_f32(local_x + 12);
float32x4_t a_r0 = vld1q_f32(alphas);
float32x4_t a_r1 = vld1q_f32(alphas);
float32x4_t a_r2 = vld1q_f32(alphas);
float32x4_t a_r3 = vld1q_f32(alphas);
r0 = math::vActiveq_f32<Act>(r0, a_r0);
r1 = math::vActiveq_f32<Act>(r1, a_r1);
r2 = math::vActiveq_f32<Act>(r2, a_r2);
r3 = math::vActiveq_f32<Act>(r3, a_r3);
vst1q_f32(local_y, r0);
vst1q_f32(local_y + 4, r1);
vst1q_f32(local_y + 8, r2);
vst1q_f32(local_y + 12, r3);
}
x += (loop << 4);
y += (loop << 4);
#endif
for (size_t i = 0; i < remain; ++i) {
y[i] = math::Active<Act>(x[i], falpha);
}
}
};
} // namespace operators
} // namespace paddle_mobile
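A minimal usage sketch of the functor (hypothetical helper; Tensor and the RELU enum come from paddle_mobile's framework and common/types.h):

// Apply ReLU in place over a float tensor, the same call the fused
// CPU conv kernel makes as its epilogue.
inline void ReluInplace(paddle_mobile::framework::Tensor *t) {
  paddle_mobile::operators::ActivationCompute<float, paddle_mobile::RELU>()(t, t);
}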
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef FUSION_CONVRELU_OP
#include "operators/kernel/conv_relu_kernel.h"
#include "operators/kernel/cl/cl-kernel-func/conv_func.h"
namespace paddle_mobile {
namespace operators {
template <>
bool ConvReluKernel<GPU_CL, float>::Init(FusionConvReluParam<GPU_CL> *param) {
PADDLE_MOBILE_ENFORCE(
param->Filter()->dims()[2] == param->Filter()->dims()[3] &&
param->Paddings()[0] == param->Paddings()[1],
"need equal");
int offset = static_cast<int>(param->Filter()->dims()[2]) / 2 -
static_cast<int>(param->Paddings()[1]);
param->SetOffset(offset);
DLOG << " init helper: " << &cl_helper_;
DLOG << " conv kernel add kernel ~ ";
DLOG << " width of one block: " << param->Filter()->dims()[3];
DLOG << " height of one block: " << param->Filter()->dims()[2];
DLOG << " filter dims: " << param->Filter()->dims();
const std::string conv_kernel_file = "conv_kernel.cl";
const std::string wino_kernel_file = "winograd_transform.cl";
const std::string build_options = "-DRELU";
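// All kernels below are built with -DRELU, so the .cl code clamps its output in the epilogue; this is where the activation gets fused on GPU.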
if (param->Filter()->dims()[2] == 1 && param->Filter()->dims()[3] == 1) {
param->ExecMode() = ConvParam<GPU_CL>::EXEC_SLIDINGWINDOW1x1_FLOAT;
param->Filter()->InitNImage(cl_helper_.CLContext(),
cl_helper_.CLCommandQueue());
this->cl_helper_.AddKernel("conv_1x1_spl", conv_kernel_file, build_options);
DLOG << "conv 1x1";
} else if (param->Filter()->dims()[1] == 1 &&
param->Input()->dims()[1] == param->Output()->dims()[1] &&
param->Filter()->dims()[2] == 3) {
param->ExecMode() = ConvParam<GPU_CL>::EXEC_DEPTHWISE3x3_FLOAT;
param->Filter()->InitDWImage(cl_helper_.CLContext(),
cl_helper_.CLCommandQueue());
this->cl_helper_.AddKernel("depth_conv_3x3", conv_kernel_file,
build_options);
DLOG << "depth_conv 3x3";
} else if (param->Filter()->dims()[2] == 3 &&
param->Filter()->dims()[3] == 3) {
// if (param->Strides()[0] == param->Strides()[1] &&
// param->Strides()[0] == 1 && param->Input()->dims()[2] >= 32) {
// param->ExecMode() = ConvParam<GPU_CL>::EXEC_WINOGRAD3X3_FLOAT;
// this->cl_helper_.AddKernel("winograd_filter_transform_2x2",
// wino_kernel_file, build_options);
// this->cl_helper_.AddKernel("winograd_input_transform_2x2",
// wino_kernel_file, build_options);
// this->cl_helper_.AddKernel("matmul", "matmul.cl", build_options);
// this->cl_helper_.AddKernel("winograd_output_transform_2x2",
// wino_kernel_file, build_options);
//
// winograd_transform_weight<4, 3>(&this->cl_helper_, param->Filter());
//
// } else {
param->ExecMode() = ConvParam<GPU_CL>::EXEC_SLIDINGWINDOW3x3_FLOAT;
param->Filter()->InitCLImage(cl_helper_.CLContext(),
cl_helper_.CLCommandQueue());
this->cl_helper_.AddKernel("conv_3x3", conv_kernel_file, build_options);
// }
DLOG << "conv 3x3";
} else {
PADDLE_MOBILE_THROW_EXCEPTION("unsupported filter shape for fusion_conv_relu");
}
return true;
}
template <>
void ConvReluKernel<GPU_CL, float>::Compute(
const FusionConvReluParam<GPU_CL> &param) {
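// Every branch passes true for the relu flag, so the shared CL conv routines apply the activation in-kernel.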
switch (param.ExecMode()) {
case ConvParam<GPU_CL>::EXEC_WINOGRAD3X3_FLOAT:
WinogradConv3x3<4, 3>(&this->cl_helper_, param, true);
break;
case ConvParam<GPU_CL>::EXEC_SLIDINGWINDOW1x1_FLOAT:
case ConvParam<GPU_CL>::EXEC_SLIDINGWINDOW3x3_FLOAT:
case ConvParam<GPU_CL>::EXEC_DEPTHWISE3x3_FLOAT:
ConvAddBnRelu(&this->cl_helper_, param, true);
break;
default:
PADDLE_MOBILE_THROW_EXCEPTION("Invalid convolution execute mode %d",
param.ExecMode());
}
}
template class ConvReluKernel<GPU_CL, float>;
} // namespace operators
} // namespace paddle_mobile
#endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#ifdef FUSION_CONVRELU_OP
#include <vector>
#include "framework/operator.h"
#include "operators/math/im2col.h"
#include "operators/math/math_function.h"
#include "operators/math/vol2col.h"
#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
using framework::OpKernelBase;
template <typename DeviceType, typename T>
class ConvReluKernel
: public OpKernelBase<DeviceType, FusionConvReluParam<DeviceType>> {
public:
void Compute(const FusionConvReluParam<DeviceType> &param);
bool Init(FusionConvReluParam<DeviceType> *param);
};
} // namespace operators
} // namespace paddle_mobile
#endif
......@@ -2240,6 +2240,22 @@ class FusionDWConvBNReluParam : public ConvParam<Dtype> {
#endif
#ifdef FUSION_CONVRELU_OP
template <typename Dtype>
class FusionConvReluParam : public ConvParam<Dtype> {
typedef typename DtypeTensorTrait<Dtype>::gtype GType;
typedef typename DtypeTensorTrait<Dtype>::rtype RType;
public:
FusionConvReluParam(const VariableNameMap &inputs,
const VariableNameMap &outputs, const AttributeMap &attrs,
Scope *scope)
: ConvParam<Dtype>(inputs, outputs, attrs, scope) {
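// The fused op's result is bound to the "Out" slot (see op_input_output_key), so rebind the output that the base ConvParam resolved from "Output".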
this->output_ = OpParam::OutFrom<GType>(outputs, *scope);
}
};
#endif
#ifdef FUSION_CONVBNRELU_OP
template <typename Dtype>
class FusionConvBNReluParam : public ConvParam<Dtype> {
......
......@@ -311,6 +311,7 @@ if(NOT FOUND_MATCH)
set(FUSION_CONVADDADDPRELU_OP ON)
set(FUSION_DWCONVBNRELU_OP ON)
set(FUSION_CONVBNRELU_OP ON)
set(FUSION_CONVRELU_OP ON)
set(FUSION_CONVBNADDRELU_OP ON)
set(PRELU_OP ON)
set(RESIZE_OP ON)
......@@ -484,6 +485,10 @@ if (FUSION_CONVBNRELU_OP)
add_definitions(-DFUSION_CONVBNRELU_OP)
endif()
if (FUSION_CONVRELU_OP)
add_definitions(-DFUSION_CONVRELU_OP)
endif()
if (FUSION_CONVBNADDRELU_OP)
add_definitions(-DFUSION_CONVBNADDRELU_OP)
endif()
......