提交 7172851d 编写于 作者: W wangliu

modify code style

上级 88078a28
...@@ -18,10 +18,10 @@ limitations under the License. */ ...@@ -18,10 +18,10 @@ limitations under the License. */
#include <string> #include <string>
#include <vector> #include <vector>
#include "operators/kernel/conv_add_kernel.h"
#include "framework/operator.h" #include "framework/operator.h"
#include "framework/program/program-optimize/fusion_op_register.h" #include "framework/program/program-optimize/fusion_op_register.h"
#include "op_param.h" #include "op_param.h"
#include "operators/kernel/conv_add_kernel.h"
namespace paddle_mobile { namespace paddle_mobile {
namespace operators { namespace operators {
......
...@@ -19,7 +19,8 @@ namespace paddle_mobile { ...@@ -19,7 +19,8 @@ namespace paddle_mobile {
namespace operators { namespace operators {
template <> template <>
void ConvAddKernel<CPU, float>::Compute(const FushionConvAddParam &param) const { void ConvAddKernel<CPU, float>::Compute(
const FushionConvAddParam &param) const {
DLOG << param; DLOG << param;
const Tensor *input = param.Input(); const Tensor *input = param.Input();
...@@ -48,7 +49,7 @@ void ConvAddKernel<CPU, float>::Compute(const FushionConvAddParam &param) const ...@@ -48,7 +49,7 @@ void ConvAddKernel<CPU, float>::Compute(const FushionConvAddParam &param) const
framework::DDim col_shape(framework::make_ddim(col_shape_vec)); framework::DDim col_shape(framework::make_ddim(col_shape_vec));
framework::DDim col_matrix_shape = framework::DDim col_matrix_shape =
framework::flatten_to_2d(col_shape, data_dim + 1); framework::flatten_to_2d(col_shape, data_dim + 1);
bool is_expand = IsExpand(filter_shape_vec, strides, paddings, dilations); bool is_expand = IsExpand(filter_shape_vec, strides, paddings, dilations);
Tensor col; Tensor col;
...@@ -60,15 +61,15 @@ void ConvAddKernel<CPU, float>::Compute(const FushionConvAddParam &param) const ...@@ -60,15 +61,15 @@ void ConvAddKernel<CPU, float>::Compute(const FushionConvAddParam &param) const
} }
framework::DDim input_shape = framework::slice_ddim( framework::DDim input_shape = framework::slice_ddim(
input->dims(), 1, static_cast<int>(input->dims().size())); input->dims(), 1, static_cast<int>(input->dims().size()));
framework::DDim filter_matrix_shape = {filter.dims()[0], framework::DDim filter_matrix_shape = {filter.dims()[0],
filter.numel() / filter.dims()[0]}; filter.numel() / filter.dims()[0]};
filter.Resize(filter_matrix_shape); filter.Resize(filter_matrix_shape);
DLOG << " filter.dims() = " << filter.dims(); DLOG << " filter.dims() = " << filter.dims();
framework::DDim output_matrix_shape = { framework::DDim output_matrix_shape = {
output->dims()[1], output->dims()[1],
output->numel() / (output->dims()[0] * output->dims()[1])}; output->numel() / (output->dims()[0] * output->dims()[1])};
// convolution operator: im2col(or vol2col) + gemm // convolution operator: im2col(or vol2col) + gemm
int in_step = static_cast<int>(input->dims()[1]) / groups; int in_step = static_cast<int>(input->dims()[1]) / groups;
......
...@@ -30,7 +30,7 @@ using framework::OpKernelBase; ...@@ -30,7 +30,7 @@ using framework::OpKernelBase;
template <typename DeviceType, typename T> template <typename DeviceType, typename T>
class ConvAddKernel : public OpKernelBase<DeviceType, FushionConvAddParam> { class ConvAddKernel : public OpKernelBase<DeviceType, FushionConvAddParam> {
public: public:
void Compute(const FushionConvAddParam &param) const; void Compute(const FushionConvAddParam &param) const;
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册