未验证 提交 af024a42 编写于 作者: E eclipsycn 提交者: GitHub

Merge pull request #404 from cocodark/develop

add conv_add op kernel
...@@ -126,7 +126,7 @@ else () ...@@ -126,7 +126,7 @@ else ()
add_definitions(-DCONV_OP) add_definitions(-DCONV_OP)
add_definitions(-DDEPTHWISECONV_OP) add_definitions(-DDEPTHWISECONV_OP)
add_definitions(-DELEMENTWISEADD_OP) add_definitions(-DELEMENTWISEADD_OP)
add_definitions(-DFUSIONCONVADD_OP) add_definitions(-DFUSION_CONVADD_OP)
add_definitions(-DCONVADDRELU_OP) add_definitions(-DCONVADDRELU_OP)
add_definitions(-DFUSION_FC_OP) add_definitions(-DFUSION_FC_OP)
add_definitions(-DLRN_OP) add_definitions(-DLRN_OP)
......
...@@ -131,7 +131,6 @@ class Tensor { ...@@ -131,7 +131,6 @@ class Tensor {
} }
PADDLE_MOBILE_ENFORCE(numel() >= 0, "the Tensor'snumel must >=0.") PADDLE_MOBILE_ENFORCE(numel() >= 0, "the Tensor'snumel must >=0.")
int64_t size = numel() * SizeOfType(type); int64_t size = numel() * SizeOfType(type);
/* some versions of boost::variant don't have operator!= */
if (holder_ == nullptr || holder_->size() < size + offset_) { if (holder_ == nullptr || holder_->size() < size + offset_) {
holder_.reset(new PlaceholderImpl(size, type)); holder_.reset(new PlaceholderImpl(size, type));
offset_ = 0; offset_ = 0;
......
...@@ -11,7 +11,7 @@ distributed under the License is distributed on an "AS IS" BASIS, ...@@ -11,7 +11,7 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#define FUSION_CONVADD_OP
#ifdef FUSION_CONVADD_OP #ifdef FUSION_CONVADD_OP
#pragma once #pragma once
......
...@@ -18,6 +18,27 @@ limitations under the License. */ ...@@ -18,6 +18,27 @@ limitations under the License. */
namespace paddle_mobile { namespace paddle_mobile {
namespace operators { namespace operators {
// Broadcast a 1-D bias tensor in place to the full shape `dDim`, so that
// element-wise addition with the conv output becomes a plain copy/add.
// The bias is replicated along every dimension except `axis` (the channel
// dimension for a conv bias): after the call, bias has shape dDim and
// bias[n, c, h, w] == original_bias[c] (for axis == 1).
//
// Parameters:
//   bias  - in/out: 1-D tensor of length dDim[axis]; resized to dDim.
//   axis  - the dimension of dDim that the bias values run along.
//   dDim  - target shape (typically the conv output dims).
void expand_bias(Tensor &bias, int axis, const DDim &dDim) {
  PADDLE_MOBILE_ENFORCE(bias.dims().size() == 1,
                        "the bias tensor's dims size != 1")
  DDim outer_ddim = paddle_mobile::framework::slice_ddim(dDim, 0, axis + 1);
  DDim inner_ddim =
      paddle_mobile::framework::slice_ddim(dDim, axis + 1, dDim.size());
  int outer_size = paddle_mobile::framework::product(outer_ddim);
  int inner_size = paddle_mobile::framework::product(inner_ddim);
  int axis_size = dDim[axis];

  // Snapshot the original bias values BEFORE resizing. Resize() followed by
  // mutable_data() reallocates the tensor's buffer when it grows (the holder
  // is reset when holder_->size() < new size), so a raw pointer taken from
  // data<float>() beforehand would dangle — reading through it in the fill
  // loop below was a use-after-free.
  auto old_ptr = bias.data<float>();
  std::vector<float> bias_values(old_ptr, old_ptr + axis_size);

  bias.Resize(dDim);
  auto new_ptr = bias.mutable_data<float>();
  for (int i = 0; i < outer_size; ++i) {
    // For an outer index i spanning dims [0, axis], the bias element along
    // `axis` is i % axis_size. (The previous `i * axis_size / outer_size`
    // only coincided with this when the leading batch dimension was 1;
    // e.g. with outer = N*C it computed i / N instead of the channel.)
    float v_bias = bias_values[i % axis_size];
    for (int j = 0; j < inner_size; ++j) {
      new_ptr[i * inner_size + j] = v_bias;
    }
  }
}
template <> template <>
void ConvAddKernel<CPU, float>::Compute( void ConvAddKernel<CPU, float>::Compute(
const FushionConvAddParam &param) const { const FushionConvAddParam &param) const {
...@@ -25,15 +46,16 @@ void ConvAddKernel<CPU, float>::Compute( ...@@ -25,15 +46,16 @@ void ConvAddKernel<CPU, float>::Compute(
const Tensor *input = param.Input(); const Tensor *input = param.Input();
Tensor filter = *param.Filter(); Tensor filter = *param.Filter();
Tensor bias = *param.Bias();
int axis = param.Axis();
Tensor *output = param.Output(); Tensor *output = param.Output();
output->mutable_data<float>(); expand_bias(bias, axis, output->dims());
output->ShareDataWith(bias);
int groups = param.Groups(); int groups = param.Groups();
std::vector<int> strides = param.Strides(); std::vector<int> strides = param.Strides();
std::vector<int> paddings = param.Paddings(); std::vector<int> paddings = param.Paddings();
std::vector<int> dilations = param.Dilations(); std::vector<int> dilations = param.Dilations();
// DLOG << " compute end get Attrs " << strides[0];
const int batch_size = static_cast<int>(input->dims()[0]); const int batch_size = static_cast<int>(input->dims()[0]);
std::vector<int64_t> filter_shape_vec(framework::vectorize(filter.dims())); std::vector<int64_t> filter_shape_vec(framework::vectorize(filter.dims()));
...@@ -66,7 +88,6 @@ void ConvAddKernel<CPU, float>::Compute( ...@@ -66,7 +88,6 @@ void ConvAddKernel<CPU, float>::Compute(
framework::DDim filter_matrix_shape = {filter.dims()[0], framework::DDim filter_matrix_shape = {filter.dims()[0],
filter.numel() / filter.dims()[0]}; filter.numel() / filter.dims()[0]};
filter.Resize(filter_matrix_shape); filter.Resize(filter_matrix_shape);
DLOG << " filter.dims() = " << filter.dims();
framework::DDim output_matrix_shape = { framework::DDim output_matrix_shape = {
output->dims()[1], output->dims()[1],
output->numel() / (output->dims()[0] * output->dims()[1])}; output->numel() / (output->dims()[0] * output->dims()[1])};
...@@ -105,7 +126,7 @@ void ConvAddKernel<CPU, float>::Compute( ...@@ -105,7 +126,7 @@ void ConvAddKernel<CPU, float>::Compute(
Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step); Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step);
math::matmul<float>(filter_slice, false, col_matrix, false, math::matmul<float>(filter_slice, false, col_matrix, false,
static_cast<float>(1), &out_slice, static_cast<float>(1), &out_slice,
static_cast<float>(0)); static_cast<float>(1));
} }
} }
} }
......
...@@ -17,6 +17,7 @@ limitations under the License. */ ...@@ -17,6 +17,7 @@ limitations under the License. */
#pragma once #pragma once
#include <vector> #include <vector>
#include "framework/ddim.h"
#include "framework/operator.h" #include "framework/operator.h"
#include "operators/math/im2col.h" #include "operators/math/im2col.h"
#include "operators/math/math_function.h" #include "operators/math/math_function.h"
...@@ -26,6 +27,7 @@ limitations under the License. */ ...@@ -26,6 +27,7 @@ limitations under the License. */
namespace paddle_mobile { namespace paddle_mobile {
namespace operators { namespace operators {
using framework::DDim;
using framework::OpKernelBase; using framework::OpKernelBase;
template <typename DeviceType, typename T> template <typename DeviceType, typename T>
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册