提交 064dc888 编写于 作者: X xzl

add the comments for .h file and code tiny modify

上级 36e7800a
...@@ -13,16 +13,16 @@ See the License for the specific language governing permissions and ...@@ -13,16 +13,16 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "DepthwiseConvOp.h" #include "DepthwiseConvOp.h"
#include "ConvOp.h"
#include "GemmFunctor.h" #include "GemmFunctor.h"
#include "paddle/math/MemoryHandle.h" //#include "paddle/math/MemoryHandle.h"
namespace paddle { namespace paddle {
template <class T> template <class T>
class DepthwiseConvFunctor<DEVICE_TYPE_CPU, T> { class DepthwiseConvFunctor<DEVICE_TYPE_CPU, T> {
public: public:
void operator()(int outputSize, void operator()(const T* inputData,
const T* inputData,
const T* filterData, const T* filterData,
int batchSize, int batchSize,
int outputChannels, int outputChannels,
...@@ -44,13 +44,13 @@ public: ...@@ -44,13 +44,13 @@ public:
template <class T> template <class T>
class DepthwiseConvGradInputFunctor<DEVICE_TYPE_CPU, T> { class DepthwiseConvGradInputFunctor<DEVICE_TYPE_CPU, T> {
public: public:
void operator()(int inputSize, void operator()(const T* outputGrad,
const T* outputGrad,
const T* filterData, const T* filterData,
int batchSize, int batchSize,
int outputChannels, int outputChannels,
int outputHeight, int outputHeight,
int outputWidth, int outputWidth,
int inputChannels,
int inputHeight, int inputHeight,
int inputWidth, int inputWidth,
int filterHeight, int filterHeight,
...@@ -65,14 +65,13 @@ public: ...@@ -65,14 +65,13 @@ public:
template <class T> template <class T>
class DepthwiseConvGradFilterFunctor<DEVICE_TYPE_CPU, T> { class DepthwiseConvGradFilterFunctor<DEVICE_TYPE_CPU, T> {
public: public:
void operator()(int num_i, void operator()(const T* outputGrad,
int colDataSize,
const T* outputGrad,
const T* inputData, const T* inputData,
int batchSize, int batchSize,
int outputChannels, int outputChannels,
int outputHeight, int outputHeight,
int outputWidth, int outputWidth,
int inputChannels,
int inputHeight, int inputHeight,
int inputWidth, int inputWidth,
int filterHeight, int filterHeight,
...@@ -87,7 +86,7 @@ public: ...@@ -87,7 +86,7 @@ public:
}; };
/* /*
* \brief Forward calculation of convolution. * \brief Forward calculation of depthwise convolution.
*/ */
template <DeviceType Device> template <DeviceType Device>
class DepthwiseConvFunction : public ConvFunctionBase { class DepthwiseConvFunction : public ConvFunctionBase {
...@@ -126,11 +125,9 @@ public: ...@@ -126,11 +125,9 @@ public:
real* inputData = inputs[0].data<real>(); real* inputData = inputs[0].data<real>();
real* filterData = inputs[1].data<real>(); real* filterData = inputs[1].data<real>();
real* outputData = outputs[0].data<real>(); real* outputData = outputs[0].data<real>();
size_t outputSize = batchSize * outputChannels * outputHeight * outputWidth;
DepthwiseConvFunctor<Device, real> depthwiseConv; DepthwiseConvFunctor<Device, real> depthwiseConv;
depthwiseConv(outputSize, depthwiseConv(inputData,
inputData,
filterData, filterData,
batchSize, batchSize,
outputChannels, outputChannels,
...@@ -149,7 +146,7 @@ public: ...@@ -149,7 +146,7 @@ public:
}; };
/* /*
* \brief Backward input calculation of convolution. * \brief Backward input calculation of depthwise convolution.
*/ */
template <DeviceType Device> template <DeviceType Device>
class DepthwiseConvGradInputFunction : public ConvFunctionBase { class DepthwiseConvGradInputFunction : public ConvFunctionBase {
...@@ -191,16 +188,14 @@ public: ...@@ -191,16 +188,14 @@ public:
real* filterData = inputs[1].data<real>(); real* filterData = inputs[1].data<real>();
real* inputGrad = outputs[0].data<real>(); real* inputGrad = outputs[0].data<real>();
size_t inputSize = batchSize * inputChannels * inputHeight * inputWidth;
DepthwiseConvGradInputFunctor<Device, real> depthwiseConvGradInput; DepthwiseConvGradInputFunctor<Device, real> depthwiseConvGradInput;
depthwiseConvGradInput(inputSize, depthwiseConvGradInput(outputGrad,
outputGrad,
filterData, filterData,
batchSize, batchSize,
outputChannels, outputChannels,
outputHeight, outputHeight,
outputWidth, outputWidth,
inputChannels,
inputHeight, inputHeight,
inputWidth, inputWidth,
filterHeight, filterHeight,
...@@ -214,7 +209,7 @@ public: ...@@ -214,7 +209,7 @@ public:
}; };
/* /*
* \brief Backward filter calculation of convolution. * \brief Backward filter calculation of depthwise convolution.
*/ */
template <DeviceType Device> template <DeviceType Device>
class DepthwiseConvGradFilterFunction : public ConvFunctionBase { class DepthwiseConvGradFilterFunction : public ConvFunctionBase {
...@@ -255,35 +250,31 @@ public: ...@@ -255,35 +250,31 @@ public:
real* multiplierData = inputs[2].data<real>(); real* multiplierData = inputs[2].data<real>();
real* filterGrad = outputs[0].data<real>(); real* filterGrad = outputs[0].data<real>();
size_t size = int size =
inputChannels * filterHeight * filterWidth * outputHeight * outputWidth; inputChannels * filterHeight * filterWidth * outputHeight * outputWidth;
resizeBuffer<Device>(size); resizeBuffer<Device>(size);
real* colData = reinterpret_cast<real*>(memory_->getBuf()); real* colData = reinterpret_cast<real*>(memory_->getBuf());
DepthwiseConvGradFilterFunctor<Device, real> depthwiseConvGradFilter; DepthwiseConvGradFilterFunctor<Device, real> depthwiseConvGradFilter;
for (size_t i = 0; i < batchSize; i++) { depthwiseConvGradFilter(outputGrad,
depthwiseConvGradFilter(i, inputData,
size, batchSize,
outputGrad, outputChannels,
inputData, outputHeight,
batchSize, outputWidth,
outputChannels, inputChannels,
outputHeight, inputHeight,
outputWidth, inputWidth,
inputHeight, filterHeight,
inputWidth, filterWidth,
filterHeight, strideH(),
filterWidth, strideW(),
strideH(), paddingH(),
strideW(), paddingW(),
paddingH(), colData,
paddingW(), multiplierData,
colData, filterGrad);
multiplierData,
filterGrad);
}
} }
}; };
......
...@@ -14,15 +14,36 @@ limitations under the License. */ ...@@ -14,15 +14,36 @@ limitations under the License. */
#pragma once #pragma once
#include "ConvOp.h" #include "TensorType.h"
namespace paddle { namespace paddle {
/**
*\brief Depthwise convolution forward. The outputData
* of depthwise convolution is same with ExpandConvLayer
* when groups equals inputChannels in ExpandConvLayer.
*
* \param[in] inputData input data.
* \param[in] filterData the Paramters of the depthwise conv layer..
* \param[in] batchSize batch size of input data.
* \param[in] outputChannels channels of outputData.
* \param[in] outputHeight height of outputData.
* \param[in] outputWidth width of outputData.
* \param[in] inputHeight height of inputData.
* \param[in] inputWidth width of inputData..
* \param[in] filterHeight height of filter.
* \param[in] filterWidth widht of filter.
* \param[in] strideH stride size in height direction.
* \param[in] strideW stride size in width direction.
* \param[in] paddingH padding size in height direction.
* \param[in] paddingW padding size in width direction.
* \param[out] outputData outputData.
*
*/
template <DeviceType Device, class T> template <DeviceType Device, class T>
class DepthwiseConvFunctor { class DepthwiseConvFunctor {
public: public:
void operator()(int outputSize, void operator()(const T* inputData,
const T* inputData,
const T* filterData, const T* filterData,
int batchSize, int batchSize,
int outputChannels, int outputChannels,
...@@ -39,16 +60,38 @@ public: ...@@ -39,16 +60,38 @@ public:
T* outputData); T* outputData);
}; };
/**
*\brief Functor tot compute the depthwise convolution backprop w.r.t input.
*
*
* \param[in] outputGradData the grad data of output.
* \param[in] filterData the Paramters of the depthwise conv layer..
* \param[in] batchSize batch size of input data.
* \param[in] outputChannels channels of outputData.
* \param[in] outputHeight height of outputData.
* \param[in] outputWidth width of outputData.
* \param[in] inputChannels channels of input data.
* \param[in] inputHeight height of inputData.
* \param[in] inputWidth width of inputData..
* \param[in] filterHeight height of filter.
* \param[in] filterWidth widht of filter.
* \param[in] strideH stride size in height direction.
* \param[in] strideW stride size in width direction.
* \param[in] paddingH padding size in height direction.
* \param[in] paddingW padding size in width direction.
* \param[out] inputGrad the grad data of input.
*
*/
template <DeviceType Device, class T> template <DeviceType Device, class T>
class DepthwiseConvGradInputFunctor { class DepthwiseConvGradInputFunctor {
public: public:
void operator()(int inputSize, void operator()(const T* outputGrad,
const T* outputGrad,
const T* filterData, const T* filterData,
int batchSize, int batchSize,
int outputChannels, int outputChannels,
int outputHeight, int outputHeight,
int outputWidth, int outputWidth,
int inputChannels,
int inputHeight, int inputHeight,
int inputWidth, int inputWidth,
int filterHeight, int filterHeight,
...@@ -60,17 +103,42 @@ public: ...@@ -60,17 +103,42 @@ public:
T* inputGrad); T* inputGrad);
}; };
/**
*\brief Functor tot compute the depthwise convolution backprop w.r.t filter.
*
* \param[in] outputGradData the grad data of output.
* \param[in] inputData inputData.
* \param[in] batchSize batch size of input data.
* \param[in] outputChannels channels of outputData.
* \param[in] outputHeight height of outputData.
* \param[in] outputWidth width of outputData.
* \param[in] inputChannels channels of input data.
* \param[in] inputHeight height of inputData.
* \param[in] inputWidth width of inputData..
* \param[in] filterHeight height of filter.
* \param[in] filterWidth widht of filter.
* \param[in] strideH stride size in height direction.
* \param[in] strideW stride size in width direction.
* \param[in] paddingH padding size in height direction.
* \param[in] paddingW padding size in width direction.
* \param[in] colData Auxiliary data when calculating filterGrad.
* size:
*inputChannels*filterHeight*filterWidth*outputHeight*outputWidth \param[in]
*multiplierData Auxiliary data when calculating filterGrad. size:
*outputHeight * outputWidth. \param[out]
*filterGrad the grad data of filter.
*
*/
template <DeviceType Device, class T> template <DeviceType Device, class T>
class DepthwiseConvGradFilterFunctor { class DepthwiseConvGradFilterFunctor {
public: public:
void operator()(int num_i, void operator()(const T* outputGrad,
int colDataSize,
const T* outputGrad,
const T* inputData, const T* inputData,
int batchSize, int batchSize,
int outputChannels, int outputChannels,
int outputHeight, int outputHeight,
int outputWidth, int outputWidth,
int inputChannels,
int inputHeight, int inputHeight,
int inputWidth, int inputWidth,
int filterHeight, int filterHeight,
......
...@@ -12,12 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,12 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "ConvOp.h"
#include "DepthwiseConvOp.h" #include "DepthwiseConvOp.h"
#include "GemmFunctor.h" #include "GemmFunctor.h"
#include "paddle/math/MemoryHandle.h"
namespace paddle { namespace paddle {
// CUDA kernel to compute the depthwise convolution forward pass
template <class T> template <class T>
__global__ __global__
void ConvolutionDepthwiseForward(const int nthreads, void ConvolutionDepthwiseForward(const int nthreads,
...@@ -48,7 +47,7 @@ void ConvolutionDepthwiseForward(const int nthreads, ...@@ -48,7 +47,7 @@ void ConvolutionDepthwiseForward(const int nthreads,
for (int kw = 0; kw < filterWidth; ++kw) { for (int kw = 0; kw < filterWidth; ++kw) {
const int h_in = -paddingH + h * strideH + kh; const int h_in = -paddingH + h * strideH + kh;
const int w_in = -paddingW + w * strideW + kw; const int w_in = -paddingW + w * strideW + kw;
const int offset = ((n * outputChannels + c) * inputHeight + h_in) const int offset = ((n * outputChannels + c) * inputHeight + h_in)
* inputWidth + w_in; * inputWidth + w_in;
value += (*weight) * inputData[offset]; value += (*weight) * inputData[offset];
++weight; ++weight;
...@@ -73,6 +72,7 @@ void ConvolutionDepthwiseForward(const int nthreads, ...@@ -73,6 +72,7 @@ void ConvolutionDepthwiseForward(const int nthreads,
} }
} }
// CUDA kernel to compute the depthwise convolution backprop w.r.t input.
template <class T> template <class T>
__global__ __global__
void ConvolutionDepthwiseInputBackward(const int nthreads, void ConvolutionDepthwiseInputBackward(const int nthreads,
...@@ -113,6 +113,7 @@ void ConvolutionDepthwiseInputBackward(const int nthreads, ...@@ -113,6 +113,7 @@ void ConvolutionDepthwiseInputBackward(const int nthreads,
} }
} }
// CUDA kernel to compute the depthwise convolution backprop w.r.t filter.
template <class T> template <class T>
__global__ __global__
void ConvolutionDepthwiseFilterBackward(const int num_i, const int nthreads, void ConvolutionDepthwiseFilterBackward(const int num_i, const int nthreads,
...@@ -150,15 +151,14 @@ void ConvolutionDepthwiseFilterBackward(const int num_i, const int nthreads, ...@@ -150,15 +151,14 @@ void ConvolutionDepthwiseFilterBackward(const int num_i, const int nthreads,
template <class T> template <class T>
class DepthwiseConvFunctor<DEVICE_TYPE_GPU, T>{ class DepthwiseConvFunctor<DEVICE_TYPE_GPU, T>{
public: public:
void operator()(int outputSize, void operator()(const T* inputData,
const T* inputData,
const T* filterData, const T* filterData,
int batchSize, int batchSize,
int outputChannels, int outputChannels,
int outputHeight, int outputHeight,
int outputWidth, int outputWidth,
int inputHeight, int inputHeight,
int inputWidth, int inputWidth,
int filterHeight, int filterHeight,
int filterWidth, int filterWidth,
int strideH, int strideH,
...@@ -167,12 +167,14 @@ public: ...@@ -167,12 +167,14 @@ public:
int paddingW, int paddingW,
T* outputData){ T* outputData){
int outputSize = batchSize * outputChannels * outputHeight * outputWidth;
size_t blocks = (outputSize + 1024 -1) / 1024; size_t blocks = (outputSize + 1024 -1) / 1024;
size_t blockX = 512; size_t blockX = 512;
size_t blockY = (blocks+512-1)/512; size_t blockY = (blocks+512-1)/512;
dim3 threads(1024, 1); dim3 threads(1024, 1);
dim3 grid(blockX, blockY); dim3 grid(blockX, blockY);
ConvolutionDepthwiseForward<T> ConvolutionDepthwiseForward<T>
<<< grid, threads, 0, STREAM_DEFAULT >>>( <<< grid, threads, 0, STREAM_DEFAULT >>>(
outputSize, outputSize,
...@@ -182,8 +184,8 @@ public: ...@@ -182,8 +184,8 @@ public:
outputChannels, outputChannels,
outputHeight, outputHeight,
outputWidth, outputWidth,
inputHeight, inputHeight,
inputWidth, inputWidth,
filterHeight, filterHeight,
filterWidth, filterWidth,
strideH, strideH,
...@@ -197,13 +199,13 @@ public: ...@@ -197,13 +199,13 @@ public:
template <class T> template <class T>
class DepthwiseConvGradInputFunctor<DEVICE_TYPE_GPU, T>{ class DepthwiseConvGradInputFunctor<DEVICE_TYPE_GPU, T>{
public: public:
void operator()(int inputSize, void operator()(const T* outputGrad,
const T* outputGrad,
const T* filterData, const T* filterData,
int batchSize, int batchSize,
int outputChannels, int outputChannels,
int outputHeight, int outputHeight,
int outputWidth, int outputWidth,
int inputChannels,
int inputHeight, int inputHeight,
int inputWidth, int inputWidth,
int filterHeight, int filterHeight,
...@@ -212,7 +214,9 @@ public: ...@@ -212,7 +214,9 @@ public:
int strideW, int strideW,
int paddingH, int paddingH,
int paddingW, int paddingW,
T* inputGrad){ T* inputGrad){
int inputSize = batchSize * inputChannels * inputHeight * inputWidth;
size_t blocks = (inputSize + 1024 -1) / 1024; size_t blocks = (inputSize + 1024 -1) / 1024;
size_t blockX = 512; size_t blockX = 512;
...@@ -220,6 +224,7 @@ public: ...@@ -220,6 +224,7 @@ public:
dim3 threads(1024, 1); dim3 threads(1024, 1);
dim3 grid(blockX, blockY); dim3 grid(blockX, blockY);
ConvolutionDepthwiseInputBackward<T> ConvolutionDepthwiseInputBackward<T>
// NOLINT_NEXT_LINE(whitespace/operators) // NOLINT_NEXT_LINE(whitespace/operators)
<<< grid, threads, 0, STREAM_DEFAULT >>>( <<< grid, threads, 0, STREAM_DEFAULT >>>(
...@@ -245,14 +250,13 @@ public: ...@@ -245,14 +250,13 @@ public:
template <class T> template <class T>
class DepthwiseConvGradFilterFunctor<DEVICE_TYPE_GPU, T> { class DepthwiseConvGradFilterFunctor<DEVICE_TYPE_GPU, T> {
public: public:
void operator()(int num_i, void operator()(const T* outputGrad,
int colDataSize,
const T* outputGrad,
const T* inputData, const T* inputData,
int batchSize, int batchSize,
int outputChannels, int outputChannels,
int outputHeight, int outputHeight,
int outputWidth, int outputWidth,
int inputChannels,
int inputHeight, int inputHeight,
int inputWidth, int inputWidth,
int filterHeight, int filterHeight,
...@@ -265,60 +269,65 @@ public: ...@@ -265,60 +269,65 @@ public:
T* multiplierData, T* multiplierData,
T* filterGrad){ T* filterGrad){
int colDataSize = inputChannels * filterHeight * filterWidth * outputHeight * outputWidth;
size_t blocks = (colDataSize + 1024 -1) / 1024; size_t blocks = (colDataSize + 1024 -1) / 1024;
size_t blockX = 512; size_t blockX = 512;
size_t blockY = (blocks+512-1)/512; size_t blockY = (blocks+512-1)/512;
dim3 threads(1024, 1); dim3 threads(1024, 1);
dim3 grid(blockX, blockY); dim3 grid(blockX, blockY);
ConvolutionDepthwiseFilterBackward<T> for(int i = 0; i < batchSize; i++) {
<<< grid, threads, 0, STREAM_DEFAULT >>>( ConvolutionDepthwiseFilterBackward<T>
num_i, <<< grid, threads, 0, STREAM_DEFAULT >>>(
colDataSize, i,
outputGrad, colDataSize,
inputData, outputGrad,
batchSize, inputData,
outputChannels, batchSize,
outputHeight, outputChannels,
outputWidth, outputHeight,
inputHeight, outputWidth,
inputWidth, inputHeight,
filterHeight, inputWidth,
filterWidth, filterHeight,
strideH, filterWidth,
strideW, strideH,
paddingH, strideW,
paddingW, paddingH,
colData paddingW,
); colData
GemmFunctor<DEVICE_TYPE_GPU, real> gemm; );
int M = colDataSize / outputHeight / outputWidth; GemmFunctor<DEVICE_TYPE_GPU, real> gemm;
int N = 1; int M = colDataSize / outputHeight / outputWidth;
int K = outputHeight * outputWidth; int N = 1;
gemm(CblasNoTrans, int K = outputHeight * outputWidth;
CblasNoTrans, gemm(CblasNoTrans,
M, CblasNoTrans,
N, M,
K, N,
(T)1.0, K,
colData, (T)1.0,
K, colData,
multiplierData, K,
N, multiplierData,
(T)1.0, N,
filterGrad, (T)1.0,
N); filterGrad,
N);
}
//gemv //gemv
} }
}; };
#ifdef PADDLE_TYPE_DOUBLE #ifdef PADDLE_TYPE_DOUBLE
using real=double; template class DepthwiseConvGradInputFunctor<DEVICE_TYPE_GPU, double>;
template class DepthwiseConvFunctor<DEVICE_TYPE_GPU, double>;
template class DepthwiseConvGradFilterFunctor<DEVICE_TYPE_GPU, double>;
#else #else
using real=float; template class DepthwiseConvGradInputFunctor<DEVICE_TYPE_GPU, float>;
template class DepthwiseConvFunctor<DEVICE_TYPE_GPU, float>;
template class DepthwiseConvGradFilterFunctor<DEVICE_TYPE_GPU, float>;
#endif #endif
template class DepthwiseConvGradInputFunctor<DEVICE_TYPE_GPU, real>;
template class DepthwiseConvFunctor<DEVICE_TYPE_GPU, real>;
template class DepthwiseConvGradFilterFunctor<DEVICE_TYPE_GPU, real>;
} // namespace paddle } // namespace paddle
...@@ -15,14 +15,9 @@ limitations under the License. */ ...@@ -15,14 +15,9 @@ limitations under the License. */
#include "DepthwiseConvLayer.h" #include "DepthwiseConvLayer.h"
#include "paddle/utils/Logging.h" #include "paddle/utils/Logging.h"
#include "paddle/utils/Stat.h" #include "paddle/utils/Stat.h"
#include <iostream>
namespace paddle { namespace paddle {
/*
* The calculation of the exconvt(convolution transpose (deconv) operation)
* is a swap of forward and backward of the calculation of exconv.
* */
REGISTER_LAYER(depthwise_conv, DepthwiseConvLayer); REGISTER_LAYER(depthwise_conv, DepthwiseConvLayer);
bool DepthwiseConvLayer::init(const LayerMap &layerMap, bool DepthwiseConvLayer::init(const LayerMap &layerMap,
...@@ -76,11 +71,12 @@ bool DepthwiseConvLayer::init(const LayerMap &layerMap, ...@@ -76,11 +71,12 @@ bool DepthwiseConvLayer::init(const LayerMap &layerMap,
#define BACKWARD_FILTER(i, inputs, outputs) \ #define BACKWARD_FILTER(i, inputs, outputs) \
backward_[2 * i + 1]->calc(inputs, outputs) backward_[2 * i + 1]->calc(inputs, outputs)
// compute the depthwise convolution forward pass
void DepthwiseConvLayer::forward(PassType passType) { void DepthwiseConvLayer::forward(PassType passType) {
Layer::forward(passType); Layer::forward(passType);
size_t batchSize = inputLayers_[0]->getOutputValue()->getHeight(); size_t batchSize = inputLayers_[0]->getOutputValue()->getHeight();
// std::cout << "outputSize" << getOutputSize() <<std::endl;
resetOutput(batchSize, getOutputSize()); resetOutput(batchSize, getOutputSize());
// Calculate the shape of the input, output, and filter. // Calculate the shape of the input, output, and filter.
...@@ -127,6 +123,7 @@ void DepthwiseConvLayer::forward(PassType passType) { ...@@ -127,6 +123,7 @@ void DepthwiseConvLayer::forward(PassType passType) {
forwardActivation(); forwardActivation();
} }
// compute the depthwise convolution backprop.
void DepthwiseConvLayer::backward(const UpdateCallback &callback) { void DepthwiseConvLayer::backward(const UpdateCallback &callback) {
backwardActivation(); backwardActivation();
......
...@@ -22,10 +22,8 @@ namespace paddle { ...@@ -22,10 +22,8 @@ namespace paddle {
/** /**
* @brief A subclass of convolution layer. * @brief A subclass of convolution layer.
* This layer expands input and use matrix multiplication to * This layer do the depthwise convolution calculation in mobilenet.
* calculate convolution operation. * The config file api is img_depthwise_conv_layer.
*
* The config file api is img_conv_layer.
*/ */
class DepthwiseConvLayer : public ExpandConvBaseLayer { class DepthwiseConvLayer : public ExpandConvBaseLayer {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册