提交 84ded49d 编写于 作者: X xzl

fix comments

上级 6e17babe
...@@ -361,6 +361,9 @@ class DepthwiseConvKernel : public framework::OpKernel<T> { ...@@ -361,6 +361,9 @@ class DepthwiseConvKernel : public framework::OpKernel<T> {
Tensor* output = context.Output<Tensor>("Output"); Tensor* output = context.Output<Tensor>("Output");
output->mutable_data<T>(context.GetPlace()); output->mutable_data<T>(context.GetPlace());
PADDLE_ENFORCE_EQ(
output->dims()[1] % input->dims()[1], 0,
"The output channels must be a multiple of the input channels");
std::vector<int> strides = context.Attr<std::vector<int>>("strides"); std::vector<int> strides = context.Attr<std::vector<int>>("strides");
std::vector<int> paddings = context.Attr<std::vector<int>>("paddings"); std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
std::vector<int> dilations = context.Attr<std::vector<int>>("dilations"); std::vector<int> dilations = context.Attr<std::vector<int>>("dilations");
......
...@@ -203,8 +203,9 @@ class DepthwiseConvFunctor<platform::CUDADeviceContext, T> { ...@@ -203,8 +203,9 @@ class DepthwiseConvFunctor<platform::CUDADeviceContext, T> {
public: public:
void operator()(const platform::CUDADeviceContext& context, void operator()(const platform::CUDADeviceContext& context,
const framework::Tensor& input, const framework::Tensor& input,
const framework::Tensor& filter, std::vector<int>& strides, const framework::Tensor& filter,
std::vector<int>& paddings, framework::Tensor* output) { const std::vector<int>& strides,
const std::vector<int>& paddings, framework::Tensor* output) {
const int batch_size = input.dims()[0]; const int batch_size = input.dims()[0];
const int input_channels = input.dims()[1]; const int input_channels = input.dims()[1];
const int input_height = input.dims()[2]; const int input_height = input.dims()[2];
...@@ -244,7 +245,8 @@ class DepthwiseConvInputGradFunctor<platform::CUDADeviceContext, T> { ...@@ -244,7 +245,8 @@ class DepthwiseConvInputGradFunctor<platform::CUDADeviceContext, T> {
const framework::Tensor& input, const framework::Tensor& input,
const framework::Tensor& filter, const framework::Tensor& filter,
const framework::Tensor& output_grad, const framework::Tensor& output_grad,
std::vector<int>& strides, std::vector<int>& paddings, const std::vector<int>& strides,
const std::vector<int>& paddings,
framework::Tensor* input_grad) { framework::Tensor* input_grad) {
const int batch_size = input.dims()[0]; const int batch_size = input.dims()[0];
const int input_channels = input.dims()[1]; const int input_channels = input.dims()[1];
...@@ -284,7 +286,8 @@ class DepthwiseConvFilterGradFunctor<platform::CUDADeviceContext, T> { ...@@ -284,7 +286,8 @@ class DepthwiseConvFilterGradFunctor<platform::CUDADeviceContext, T> {
void operator()(const platform::CUDADeviceContext& context, void operator()(const platform::CUDADeviceContext& context,
const framework::Tensor& input, const framework::Tensor& input,
const framework::Tensor& output_grad, const framework::Tensor& output_grad,
std::vector<int>& strides, std::vector<int>& paddings, const std::vector<int>& strides,
const std::vector<int>& paddings,
framework::Tensor* filter_grad) { framework::Tensor* filter_grad) {
const int batch_size = input.dims()[0]; const int batch_size = input.dims()[0];
const int input_channels = input.dims()[1]; const int input_channels = input.dims()[1];
......
...@@ -29,8 +29,9 @@ template <typename DeviceContext, typename T> ...@@ -29,8 +29,9 @@ template <typename DeviceContext, typename T>
class DepthwiseConvFunctor { class DepthwiseConvFunctor {
public: public:
void operator()(const DeviceContext& context, const framework::Tensor& input, void operator()(const DeviceContext& context, const framework::Tensor& input,
const framework::Tensor& filter, std::vector<int>& strides, const framework::Tensor& filter,
std::vector<int>& paddings, framework::Tensor* output); const std::vector<int>& strides,
const std::vector<int>& paddings, framework::Tensor* output);
}; };
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
...@@ -39,7 +40,8 @@ class DepthwiseConvInputGradFunctor { ...@@ -39,7 +40,8 @@ class DepthwiseConvInputGradFunctor {
void operator()(const DeviceContext& context, const framework::Tensor& input, void operator()(const DeviceContext& context, const framework::Tensor& input,
const framework::Tensor& filter, const framework::Tensor& filter,
const framework::Tensor& output_grad, const framework::Tensor& output_grad,
std::vector<int>& strides, std::vector<int>& paddings, const std::vector<int>& strides,
const std::vector<int>& paddings,
framework::Tensor* input_grad); framework::Tensor* input_grad);
}; };
...@@ -48,7 +50,8 @@ class DepthwiseConvFilterGradFunctor { ...@@ -48,7 +50,8 @@ class DepthwiseConvFilterGradFunctor {
public: public:
void operator()(const DeviceContext& context, const framework::Tensor& input, void operator()(const DeviceContext& context, const framework::Tensor& input,
const framework::Tensor& output_grad, const framework::Tensor& output_grad,
std::vector<int>& strides, std::vector<int>& paddings, const std::vector<int>& strides,
const std::vector<int>& paddings,
framework::Tensor* filter_grad); framework::Tensor* filter_grad);
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册