未验证 提交 fd7e6431 编写于 作者: Q qingqing01 提交者: GitHub

Convolution fusion operator. (#14449)

* Convolution fusion operator.
* Clean code
test=develop
上级 d7bd0361
......@@ -111,7 +111,7 @@ function(op_library TARGET)
# Define operators that don't need pybind here.
foreach(manual_pybind_op "compare_op" "logical_op" "nccl_op"
"tensor_array_read_write_op" "tensorrt_engine_op")
"tensor_array_read_write_op" "tensorrt_engine_op" "conv_fusion_op")
if ("${TARGET}" STREQUAL "${manual_pybind_op}")
set(pybind_flag 1)
endif()
......
......@@ -34,7 +34,7 @@ if (WITH_GPU AND TENSORRT_FOUND)
add_subdirectory(tensorrt)
endif()
register_operators(EXCLUDES warpctc_op)
register_operators(EXCLUDES warpctc_op conv_fusion_op)
# warpctc_cudnn need cudnn 7 above
if (WITH_GPU)
......@@ -43,6 +43,8 @@ if (WITH_GPU)
else()
op_library(warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale)
endif()
op_library(conv_fusion_op)
file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(conv2d_fusion);\n")
else()
op_library(warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale)
endif()
......
......@@ -43,26 +43,6 @@ using DataLayout = platform::DataLayout;
template <typename T>
using ScalingParamType = typename platform::CudnnDataType<T>::ScalingParamType;
static constexpr char kCUDNNFwdAlgoCache[] = "kCUDNNFwdAlgoCache";
static constexpr char kCUDNNBwdDataAlgoCache[] = "kCUDNNBwdDataAlgoCache";
static constexpr char kCUDNNBwdFilterAlgoCache[] = "kCUDNNBwdFilterAlgoCache";
static constexpr size_t kCONV_CUDNN_WORKSPACE_LIMIT_BYTES =
static_cast<size_t>(1024) * 1024 * 1024;
#if CUDNN_VERSION_MIN(6, 0, 5)
static constexpr size_t kNUM_CUDNN_FWD_ALGS = CUDNN_CONVOLUTION_FWD_ALGO_COUNT;
static constexpr size_t kNUM_CUDNN_BWD_FILTER_ALGS =
CUDNN_CONVOLUTION_BWD_FILTER_ALGO_COUNT;
static constexpr size_t kNUM_CUDNN_BWD_DATA_ALGS =
CUDNN_CONVOLUTION_BWD_DATA_ALGO_COUNT;
#else
// cuDNN v5 has no CUDNN_CONVOLUTION_FWD_ALGO_COUNT etc.
static constexpr size_t kNUM_CUDNN_FWD_ALGS = 7;
static constexpr size_t kNUM_CUDNN_BWD_FILTER_ALGS = 4;
static constexpr size_t kNUM_CUDNN_BWD_DATA_ALGS = 5;
#endif
template <typename T>
class CUDNNConvOpKernel : public framework::OpKernel<T> {
public:
......
......@@ -17,10 +17,31 @@ limitations under the License. */
#include <functional>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/platform/cudnn_helper.h"
namespace paddle {
namespace operators {
static constexpr char kCUDNNFwdAlgoCache[] = "kCUDNNFwdAlgoCache";
static constexpr char kCUDNNBwdDataAlgoCache[] = "kCUDNNBwdDataAlgoCache";
static constexpr char kCUDNNBwdFilterAlgoCache[] = "kCUDNNBwdFilterAlgoCache";
static constexpr size_t kCONV_CUDNN_WORKSPACE_LIMIT_BYTES =
static_cast<size_t>(1024) * 1024 * 1024;
#if CUDNN_VERSION_MIN(6, 0, 5)
static constexpr size_t kNUM_CUDNN_FWD_ALGS = CUDNN_CONVOLUTION_FWD_ALGO_COUNT;
static constexpr size_t kNUM_CUDNN_BWD_FILTER_ALGS =
CUDNN_CONVOLUTION_BWD_FILTER_ALGO_COUNT;
static constexpr size_t kNUM_CUDNN_BWD_DATA_ALGS =
CUDNN_CONVOLUTION_BWD_DATA_ALGO_COUNT;
#else
// cuDNN v5 has no CUDNN_CONVOLUTION_FWD_ALGO_COUNT etc.
static constexpr size_t kNUM_CUDNN_FWD_ALGS = 7;
static constexpr size_t kNUM_CUDNN_BWD_FILTER_ALGS = 4;
static constexpr size_t kNUM_CUDNN_BWD_DATA_ALGS = 5;
#endif
template <typename TAlgorithm>
class AlgorithmsCache {
public:
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <string>
#include <vector>
#include "paddle/fluid/operators/conv_op.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/cudnn_helper.h"
#endif
namespace paddle {
namespace operators {
// This fused conv follows the equation:
// y = act ( alpha1 * conv(x) + alpha2 * z + bias ).
// here, y is Output,
// x is Input,
// z is ResidualData,
// bias is Bias
class Conv2DFusionOpMaker : public Conv2DOpMaker {
protected:
void Apply() override {
AddAttr<std::string>(
"activation",
"The activation type can be 'identity', 'sigmoid', 'relu', 'relu6' "
"'relux' , 'tanh', 'band_pass'")
.SetDefault("relu");
}
};
// TODO(qingqing): add gradient operator for conv2d_fusion
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(conv2d_fusion, ops::ConvOp, ops::Conv2DFusionOpMaker,
ops::ConvOpInferVarType, paddle::framework::EmptyGradOpMaker);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/conv_cudnn_op_cache.h"
#include "paddle/fluid/platform/cudnn_helper.h"
DECLARE_uint64(conv_workspace_size_limit);
DECLARE_bool(cudnn_exhaustive_search);
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using ScopedTensorDescriptor = platform::ScopedTensorDescriptor;
using ScopedFilterDescriptor = platform::ScopedFilterDescriptor;
using ScopedConvolutionDescriptor = platform::ScopedConvolutionDescriptor;
using ScopedActivationDescriptor = platform::ScopedActivationDescriptor;
using DataLayout = platform::DataLayout;
template <typename T>
using ScalingParamType = typename platform::CudnnDataType<T>::ScalingParamType;
template <typename T>
class CUDNNConvFusionOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
auto* input = ctx.Input<Tensor>("Input");
auto* filter = ctx.Input<Tensor>("Filter");
auto* bias = ctx.Input<Tensor>("Bias");
PADDLE_ENFORCE(bias, "The bias should not be null.");
auto* residual = ctx.Input<Tensor>("ResidualData");
auto* output = ctx.Output<Tensor>("Output");
std::vector<int> strides = ctx.Attr<std::vector<int>>("strides");
std::vector<int> paddings = ctx.Attr<std::vector<int>>("paddings");
std::vector<int> dilations = ctx.Attr<std::vector<int>>("dilations");
const std::string activation = ctx.Attr<std::string>("activation");
int groups = ctx.Attr<int>("groups");
int64_t user_workspace_size =
static_cast<size_t>(ctx.Attr<int>("workspace_size_MB"));
bool exhaustive_search =
FLAGS_cudnn_exhaustive_search || ctx.Attr<bool>("exhaustive_search");
const T* input_data = input->data<T>();
const T* filter_data = filter->data<T>();
const T* bias_data = bias->data<T>();
T* output_data = output->mutable_data<T>(ctx.GetPlace());
const T* residual_data = residual ? residual->data<T>() : output_data;
// ------------------- cudnn descriptors ---------------------
ScopedTensorDescriptor input_desc;
ScopedTensorDescriptor output_desc;
ScopedFilterDescriptor filter_desc;
ScopedTensorDescriptor bias_desc;
ScopedConvolutionDescriptor conv_desc;
ScopedActivationDescriptor act_desc;
DataLayout layout = DataLayout::kNCHW;
if (input->dims().size() == 5) {
layout = DataLayout::kNCDHW;
}
cudnnConvolutionDescriptor_t cudnn_conv_desc =
conv_desc.descriptor<T>(paddings, strides, dilations);
CUDNN_ENFORCE(platform::dynload::cudnnSetConvolutionGroupCount(
cudnn_conv_desc, groups));
cudnnTensorDescriptor_t cudnn_input_desc = input_desc.descriptor<T>(
layout, framework::vectorize2int(input->dims()));
cudnnTensorDescriptor_t cudnn_output_desc = output_desc.descriptor<T>(
layout, framework::vectorize2int(output->dims()));
cudnnFilterDescriptor_t cudnn_filter_desc = filter_desc.descriptor<T>(
layout, framework::vectorize2int(filter->dims()));
// Now only support NCHW
std::vector<int> bias_dim = {1, static_cast<int>(output->dims()[1]), 1, 1};
cudnnTensorDescriptor_t cudnn_bias_desc =
bias_desc.descriptor<T>(layout, bias_dim);
cudnnActivationDescriptor_t cudnn_act_desc =
act_desc.descriptor<T>(activation);
// ------------------- cudnn conv workspace ---------------------
size_t workspace_size_in_bytes; // final workspace to allocate.
size_t workspace_size_limit = kCONV_CUDNN_WORKSPACE_LIMIT_BYTES;
if (FLAGS_conv_workspace_size_limit > 0 || user_workspace_size > 0) {
int64_t max_user_size =
std::max(static_cast<int64_t>(FLAGS_conv_workspace_size_limit),
user_workspace_size);
workspace_size_limit = max_user_size * 1024 * 1024;
}
// ------------------- cudnn conv algorithm ---------------------
cudnnConvolutionFwdAlgo_t algo;
auto handle = dev_ctx.cudnn_handle();
auto workspace_handle = dev_ctx.cudnn_workspace_handle();
CUDNN_ENFORCE(platform::dynload::cudnnSetConvolutionMathType(
cudnn_conv_desc, CUDNN_DEFAULT_MATH));
auto x_dims = framework::vectorize(input->dims());
auto f_dims = framework::vectorize(filter->dims());
if (activation == "identity") {
// Only the CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM algo is
// enabled with CUDNN_ACTIVATION_IDENTITY in cuDNN lib.
algo = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM;
} else if (!exhaustive_search) {
CUDNN_ENFORCE(platform::dynload::cudnnGetConvolutionForwardAlgorithm(
handle, cudnn_input_desc, cudnn_filter_desc, cudnn_conv_desc,
cudnn_output_desc, CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT,
workspace_size_limit, &algo));
VLOG(3) << "cuDNN forward algo " << algo;
} else {
AlgorithmsCache<cudnnConvolutionFwdAlgo_t>* algo_cache = nullptr;
if (ctx.scope().FindVar(kCUDNNFwdAlgoCache)) {
algo_cache =
ctx.scope()
.FindVar(kCUDNNFwdAlgoCache)
->GetMutable<AlgorithmsCache<cudnnConvolutionFwdAlgo_t>>();
} else {
algo_cache =
const_cast<framework::Scope&>(ctx.scope())
.Var(kCUDNNFwdAlgoCache)
->GetMutable<AlgorithmsCache<cudnnConvolutionFwdAlgo_t>>();
}
algo = algo_cache->GetAlgorithm(
x_dims, f_dims, strides, paddings, dilations, 0, [&]() {
int returned_algo_count;
std::array<cudnnConvolutionFwdAlgoPerf_t, kNUM_CUDNN_FWD_ALGS>
fwd_perf_stat;
auto cudnn_find_func = [&](void* cudnn_workspace) {
CUDNN_ENFORCE(
platform::dynload::cudnnFindConvolutionForwardAlgorithmEx(
handle, cudnn_input_desc, input_data, cudnn_filter_desc,
filter_data, cudnn_conv_desc, cudnn_output_desc,
output_data, kNUM_CUDNN_FWD_ALGS, &returned_algo_count,
fwd_perf_stat.data(), cudnn_workspace,
workspace_size_limit));
};
workspace_handle.RunFunc(cudnn_find_func, workspace_size_limit);
VLOG(3) << "Perf result: (algo: stat, time, memory)";
for (int i = 0; i < returned_algo_count; ++i) {
const auto& stat = fwd_perf_stat[i];
VLOG(3) << stat.algo << ": " << stat.status << " " << stat.time
<< " " << stat.memory;
}
return fwd_perf_stat[0].algo;
});
VLOG(3) << "choose algo " << algo;
}
CUDNN_ENFORCE(platform::dynload::cudnnGetConvolutionForwardWorkspaceSize(
handle, cudnn_input_desc, cudnn_filter_desc, cudnn_conv_desc,
cudnn_output_desc, algo, &workspace_size_in_bytes));
PADDLE_ENFORCE_LE(workspace_size_in_bytes, workspace_size_limit,
"workspace_size to be allocated exceeds the limit");
// ------------------- cudnn conv+bias+act forward --------------------
ScalingParamType<T> alpha1 = 1.0f;
ScalingParamType<T> alpha2 = residual ? 1.0f : 0.0f;
auto cudnn_func = [&](void* cudnn_workspace) {
CUDNN_ENFORCE(platform::dynload::cudnnConvolutionBiasActivationForward(
handle, &alpha1, cudnn_input_desc, input_data, cudnn_filter_desc,
filter_data, cudnn_conv_desc, algo, cudnn_workspace,
workspace_size_in_bytes, &alpha2, cudnn_output_desc, residual_data,
cudnn_bias_desc, bias_data, cudnn_act_desc, cudnn_output_desc,
output_data));
};
workspace_handle.RunFunc(cudnn_func, workspace_size_in_bytes);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(conv2d_fusion, ops::CUDNNConvFusionOpKernel<float>,
ops::CUDNNConvFusionOpKernel<double>);
......@@ -225,17 +225,9 @@ $$
W_{out}= \frac{(W_{in} + 2 * paddings[1] - (dilations[1] * (W_f - 1) + 1))}{strides[1]}+ 1
$$
)DOC");
Apply();
}
class ConvOpInferVarType : public framework::PassInDtypeAndVarTypeToOutput {
protected:
std::unordered_map<std::string, std::string> GetInputOutputWithSameType()
const override {
return std::unordered_map<std::string, std::string>{
{"Input", /*->*/ "Output"}};
}
};
void Conv3DOpMaker::Make() {
AddInput(
"Input",
......@@ -334,6 +326,7 @@ Example:
W_{out}= \frac{(W_{in} + 2 * paddings[2] - (dilations[2] * (W_f - 1) + 1))}{ strides[2]}+ 1
$$
)DOC");
Apply();
}
void ConvOpGrad::InferShape(framework::InferShapeContext* ctx) const {
......
......@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once
#include <string>
#include <vector>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
......@@ -60,12 +61,27 @@ inline bool IsExpand(const std::vector<int64_t>& filter_dim,
// operator implementations can reuse the code.
class Conv2DOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override;
void Make() final;
protected:
virtual void Apply() {}
};
class Conv3DOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override;
void Make() final;
protected:
virtual void Apply() {}
};
class ConvOpInferVarType : public framework::PassInDtypeAndVarTypeToOutput {
protected:
std::unordered_map<std::string, std::string> GetInputOutputWithSameType()
const override {
return std::unordered_map<std::string, std::string>{
{"Input", /*->*/ "Output"}};
}
};
class ConvOp : public framework::OperatorWithKernel {
......
......@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once
#include <string>
#include <vector>
#include "paddle/fluid/framework/operator.h"
......@@ -81,6 +82,16 @@ enum class PoolingMode {
kAverageInclusive,
};
enum ActivationMode {
kNone, // activation identity
kSigmoid,
kRelu,
kRelu6,
kReluX,
kTanh,
kBandPass,
};
#if CUDNN_VERSION < 6000
#pragma message "CUDNN version under 6.0 is supported at best effort."
#pragma message "We strongly encourage you to move to 6.0 and above."
......@@ -120,6 +131,26 @@ inline cudnnPoolingMode_t GetPoolingMode(const PoolingMode& mode) {
}
#endif // CUDNN_VERSION < 6000
inline ActivationMode StringToActivationMode(const std::string& str) {
if (str == "identity") {
return ActivationMode::kNone;
} else if (str == "sigmoid") {
return ActivationMode::kSigmoid;
} else if (str == "relu") {
return ActivationMode::kRelu;
} else if (str == "relu6") {
return ActivationMode::kRelu6;
} else if (str == "relux") {
return ActivationMode::kReluX;
} else if (str == "tanh") {
return ActivationMode::kTanh;
} else if (str == "bandpass") {
return ActivationMode::kBandPass;
} else {
PADDLE_THROW("Unknown activation string: %s", str);
}
}
template <typename T>
class CudnnDataType;
......@@ -368,6 +399,58 @@ class ScopedSpatialTransformerDescriptor {
DISABLE_COPY_AND_ASSIGN(ScopedSpatialTransformerDescriptor);
};
class ScopedActivationDescriptor {
public:
ScopedActivationDescriptor() {
PADDLE_ENFORCE(dynload::cudnnCreateActivationDescriptor(&desc_));
}
~ScopedActivationDescriptor() {
PADDLE_ENFORCE(dynload::cudnnDestroyActivationDescriptor(desc_));
}
template <typename T>
inline cudnnActivationDescriptor_t descriptor(
const std::string& act, double value_max = static_cast<double>(0.)) {
double relu_ceiling = 0.0;
ActivationMode activation_mode = StringToActivationMode(act);
cudnnActivationMode_t mode;
switch (activation_mode) {
#if CUDNN_VERSION >= 7100
case ActivationMode::kNone:
mode = CUDNN_ACTIVATION_IDENTITY;
break;
#endif
case ActivationMode::kRelu6:
relu_ceiling = 6.0;
mode = CUDNN_ACTIVATION_CLIPPED_RELU;
break;
case ActivationMode::kReluX:
relu_ceiling = value_max;
mode = CUDNN_ACTIVATION_CLIPPED_RELU;
break;
case ActivationMode::kRelu:
mode = CUDNN_ACTIVATION_RELU;
break;
case ActivationMode::kSigmoid:
mode = CUDNN_ACTIVATION_SIGMOID;
break;
case ActivationMode::kTanh:
mode = CUDNN_ACTIVATION_TANH;
break;
default:
PADDLE_THROW("unrecognized activation mode: %d .",
static_cast<int>(activation_mode));
}
CUDNN_ENFORCE(dynload::cudnnSetActivationDescriptor(
desc_, mode, CUDNN_NOT_PROPAGATE_NAN, relu_ceiling));
return desc_;
}
private:
cudnnActivationDescriptor_t desc_;
DISABLE_COPY_AND_ASSIGN(ScopedActivationDescriptor);
};
inline bool CanCUDNNBeUsed(const framework::ExecutionContext& ctx) {
bool use_cudnn = ctx.Attr<bool>("use_cudnn");
use_cudnn &= paddle::platform::is_gpu_place(ctx.GetPlace());
......
......@@ -155,6 +155,7 @@ CUDNN_DNN_ROUTINE_EACH_R5(DECLARE_DYNAMIC_LOAD_CUDNN_WRAP)
#define CUDNN_DNN_ROUTINE_EACH_R7(__macro) \
__macro(cudnnSetConvolutionGroupCount); \
__macro(cudnnSetConvolutionMathType); \
__macro(cudnnConvolutionBiasActivationForward); \
__macro(cudnnCreateCTCLossDescriptor); \
__macro(cudnnDestroyCTCLossDescriptor); \
__macro(cudnnGetCTCLossDescriptor); \
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import paddle.fluid.core as core
from op_test import OpTest
from test_conv2d_op import conv2d_forward_naive
class TestConv2dFusionOp(OpTest):
def setUp(self):
self.op_type = "conv2d_fusion"
self.exhaustive_search = False
self.data_format = "AnyLayout"
self.dtype = np.float32
self.activation = 'relu'
self.add_bias = True
self.add_residual_data = True
self.init_group()
self.init_dilation()
self.init_test_case()
self.init_bias_residual()
self.init_activation()
self.set_search_method()
conv2d_param = {
'stride': self.stride,
'pad': self.pad,
'dilation': self.dilations
}
input = np.random.random(self.input_size).astype(self.dtype)
filter = np.random.random(self.filter_size).astype(self.dtype)
output = conv2d_forward_naive(input, filter, self.groups,
conv2d_param).astype(self.dtype)
self.inputs = {
'Input': OpTest.np_dtype_to_fluid_dtype(input),
'Filter': OpTest.np_dtype_to_fluid_dtype(filter)
}
if self.add_residual_data:
residual_data = np.random.random(output.shape).astype(self.dtype)
self.inputs['ResidualData'] = OpTest.np_dtype_to_fluid_dtype(
residual_data)
output += residual_data
if self.add_bias:
bias = np.random.random(self.filter_size[0]).astype(self.dtype)
self.inputs['Bias'] = OpTest.np_dtype_to_fluid_dtype(bias)
output = output + bias.reshape((1, bias.size, 1, 1))
assert self.activation in ['relu', 'identity']
if self.activation == 'relu':
output = np.maximum(output, 0)
self.attrs = {
'strides': self.stride,
'paddings': self.pad,
'groups': self.groups,
'dilations': self.dilations,
'data_format': self.data_format,
'exhaustive_search': self.exhaustive_search,
'activation': self.activation
}
self.outputs = {'Output': output}
def testcuda(self):
return core.is_compiled_with_cuda()
def test_check_output(self):
if self.testcuda():
place = core.CUDAPlace(0)
self.check_output_with_place(place, atol=1e-5)
else:
pass
def init_test_case(self):
self.pad = [0, 0]
self.stride = [1, 1]
self.input_size = [2, 3, 5, 5] # NCHW
assert np.mod(self.input_size[1], self.groups) == 0
f_c = self.input_size[1] // self.groups
self.filter_size = [6, f_c, 3, 3]
def init_dilation(self):
self.dilations = [1, 1]
def init_group(self):
self.groups = 1
def init_bias_residual(self):
self.add_bias = True
self.add_residual_data = True
def init_activation(self):
self.activation = 'relu'
def set_search_method(self):
self.exhaustive_search = False
class TestWithoutResidual(TestConv2dFusionOp):
def init_bias_residual(self):
self.add_residual_data = False
class TestIdentityActivation(TestConv2dFusionOp):
def init_activation(self):
self.activation = 'identity'
class TestWithGroup(TestConv2dFusionOp):
def init_group(self):
self.groups = 3
class TestWithDilation(TestConv2dFusionOp):
def init_test_case(self):
self.pad = [0, 0]
self.stride = [1, 1]
self.input_size = [2, 3, 10, 10] # NCHW
assert np.mod(self.input_size[1], self.groups) == 0
f_c = self.input_size[1] // self.groups
self.filter_size = [6, f_c, 3, 3]
def init_dilation(self):
self.dilations = [2, 2]
def init_group(self):
self.groups = 3
class TestCUDNNExhaustiveSearch(TestConv2dFusionOp):
def set_search_method(self):
self.exhaustive_search = True
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册