提交 84ded49d 编写于 作者: X xzl

fix comments

上级 6e17babe
......@@ -361,6 +361,9 @@ class DepthwiseConvKernel : public framework::OpKernel<T> {
Tensor* output = context.Output<Tensor>("Output");
output->mutable_data<T>(context.GetPlace());
PADDLE_ENFORCE_EQ(
output->dims()[1] % input->dims()[1], 0,
"The output channels must be a multiple of the input channels");
std::vector<int> strides = context.Attr<std::vector<int>>("strides");
std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
std::vector<int> dilations = context.Attr<std::vector<int>>("dilations");
......
......@@ -203,8 +203,9 @@ class DepthwiseConvFunctor<platform::CUDADeviceContext, T> {
public:
void operator()(const platform::CUDADeviceContext& context,
const framework::Tensor& input,
const framework::Tensor& filter, std::vector<int>& strides,
std::vector<int>& paddings, framework::Tensor* output) {
const framework::Tensor& filter,
const std::vector<int>& strides,
const std::vector<int>& paddings, framework::Tensor* output) {
const int batch_size = input.dims()[0];
const int input_channels = input.dims()[1];
const int input_height = input.dims()[2];
......@@ -244,7 +245,8 @@ class DepthwiseConvInputGradFunctor<platform::CUDADeviceContext, T> {
const framework::Tensor& input,
const framework::Tensor& filter,
const framework::Tensor& output_grad,
std::vector<int>& strides, std::vector<int>& paddings,
const std::vector<int>& strides,
const std::vector<int>& paddings,
framework::Tensor* input_grad) {
const int batch_size = input.dims()[0];
const int input_channels = input.dims()[1];
......@@ -284,7 +286,8 @@ class DepthwiseConvFilterGradFunctor<platform::CUDADeviceContext, T> {
void operator()(const platform::CUDADeviceContext& context,
const framework::Tensor& input,
const framework::Tensor& output_grad,
std::vector<int>& strides, std::vector<int>& paddings,
const std::vector<int>& strides,
const std::vector<int>& paddings,
framework::Tensor* filter_grad) {
const int batch_size = input.dims()[0];
const int input_channels = input.dims()[1];
......
......@@ -29,8 +29,9 @@ template <typename DeviceContext, typename T>
class DepthwiseConvFunctor {
public:
void operator()(const DeviceContext& context, const framework::Tensor& input,
const framework::Tensor& filter, std::vector<int>& strides,
std::vector<int>& paddings, framework::Tensor* output);
const framework::Tensor& filter,
const std::vector<int>& strides,
const std::vector<int>& paddings, framework::Tensor* output);
};
template <typename DeviceContext, typename T>
......@@ -39,7 +40,8 @@ class DepthwiseConvInputGradFunctor {
void operator()(const DeviceContext& context, const framework::Tensor& input,
const framework::Tensor& filter,
const framework::Tensor& output_grad,
std::vector<int>& strides, std::vector<int>& paddings,
const std::vector<int>& strides,
const std::vector<int>& paddings,
framework::Tensor* input_grad);
};
......@@ -48,7 +50,8 @@ class DepthwiseConvFilterGradFunctor {
public:
void operator()(const DeviceContext& context, const framework::Tensor& input,
const framework::Tensor& output_grad,
std::vector<int>& strides, std::vector<int>& paddings,
const std::vector<int>& strides,
const std::vector<int>& paddings,
framework::Tensor* filter_grad);
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册