提交 de3188f9 编写于 作者: H hong19860320 提交者: GitHub

enable reshape, reshape2 and scale op (#17802)

* enable reshape&reshape2 op and add unit test

* enable scale op and add unit test

* fix XShape checking for reshape2 unit test
test=develop

* remove op_desc.hasInput(..) and op_desc.hasOutput(..) to adapt X86 platform
test=develop

* remove target() and precision() from softmax, reshape, reshape2 and scale op
fix CopyDataFrom() of TensorHvy
test=develop

* alloc memory then copy tensor's data during invoking TensorHvy::CopyDataFrom(...)
test=develop
上级 95bc0ce7
...@@ -6,4 +6,4 @@ if(NOT (LITE_WITH_LIGHT_WEIGHT_FRAMEWORK AND LITE_WITH_ARM)) ...@@ -6,4 +6,4 @@ if(NOT (LITE_WITH_LIGHT_WEIGHT_FRAMEWORK AND LITE_WITH_ARM))
return() return()
endif() endif()
cc_library(math_arm SRCS funcs.cc packed_sgemm.cc softmax.cc DEPS ${lite_kernel_deps} eigen3) cc_library(math_arm SRCS funcs.cc packed_sgemm.cc softmax.cc scale.cc DEPS ${lite_kernel_deps} eigen3)
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/arm/math/scale.h"
#include "paddle/fluid/lite/arm/math/funcs.h"
namespace paddle {
namespace lite {
namespace arm {
namespace math {
template <>
void scale<float>(const float* din, float* dout, int num, float scale,
float bias) {
int cnt = num >> 4;
int remain = num % 16;
float32x4_t vscale = vdupq_n_f32(scale);
float32x4_t vbias = vdupq_n_f32(bias);
#pragma omp parallel for
for (int i = 0; i < cnt; i++) {
const float* din_ptr = din + (i << 4);
float* dout_ptr = dout + (i << 4);
float32x4_t din0 = vld1q_f32(din_ptr);
float32x4_t din1 = vld1q_f32(din_ptr + 4);
float32x4_t din2 = vld1q_f32(din_ptr + 8);
float32x4_t din3 = vld1q_f32(din_ptr + 12);
float32x4_t vsum1 = vmlaq_f32(vbias, din0, vscale);
float32x4_t vsum2 = vmlaq_f32(vbias, din1, vscale);
float32x4_t vsum3 = vmlaq_f32(vbias, din2, vscale);
float32x4_t vsum4 = vmlaq_f32(vbias, din3, vscale);
vst1q_f32(dout_ptr, vsum1);
vst1q_f32(dout_ptr + 4, vsum2);
vst1q_f32(dout_ptr + 8, vsum3);
vst1q_f32(dout_ptr + 12, vsum4);
}
if (remain > 0) {
const float* din_ptr = din + (cnt << 4);
float* dout_ptr = dout + (cnt << 4);
for (int i = 0; i < remain; i++) {
*dout_ptr = *din_ptr * scale + bias;
dout_ptr++;
din_ptr++;
}
}
}
} // namespace math
} // namespace arm
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
namespace paddle {
namespace lite {
namespace arm {
namespace math {
template <typename T>
void scale(const T* din, T* dout, int num, float scale, float bias);
} // namespace math
} // namespace arm
} // namespace lite
} // namespace paddle
...@@ -102,7 +102,8 @@ class TensorHvy : public TensorBase<TensorHvy> { ...@@ -102,7 +102,8 @@ class TensorHvy : public TensorBase<TensorHvy> {
data_.ShareDataWith(other.data_); data_.ShareDataWith(other.data_);
} }
void CopyDataFrom(const TensorHvy& other) { void CopyDataFrom(const TensorHvy& other) {
data_.ShareDataWith(other.data_); data_.mutable_data(other.data_.place(), other.data_.type());
TensorCopySync(other.data_, data_.place(), &data_);
} }
DDimT dims() const { return DDimT(framework::vectorize(data_.dims())); } DDimT dims() const { return DDimT(framework::vectorize(data_.dims())); }
......
...@@ -37,7 +37,7 @@ class DDimLite : public DDimBase<DDimLite> { ...@@ -37,7 +37,7 @@ class DDimLite : public DDimBase<DDimLite> {
void ConstructFrom(const std::vector<value_type> &x) { data_ = x; } void ConstructFrom(const std::vector<value_type> &x) { data_ = x; }
value_type operator[](int offset) const { return data_[offset]; } value_type operator[](int offset) const { return data_[offset]; }
std::vector<int64_t> Vectorize() { return data_; } std::vector<int64_t> Vectorize() const { return data_; }
size_t size() const { return data_.size(); } size_t size() const { return data_.size(); }
bool empty() const { return data_.empty(); } bool empty() const { return data_.empty(); }
......
...@@ -48,7 +48,7 @@ class DDimBase { ...@@ -48,7 +48,7 @@ class DDimBase {
explicit DDimBase(const std::vector<int64_t> &x) { self()->ConstructFrom(x); } explicit DDimBase(const std::vector<int64_t> &x) { self()->ConstructFrom(x); }
value_type operator[](int offset) const { return (*self())[offset]; } value_type operator[](int offset) const { return (*self())[offset]; }
std::vector<int64_t> Vectorize() { return self()->Vectorize(); } std::vector<int64_t> Vectorize() const { return self()->Vectorize(); }
size_t size() const { return const_self()->size(); } size_t size() const { return const_self()->size(); }
bool empty() const { return const_self()->empty(); } bool empty() const { return const_self()->empty(); }
......
...@@ -7,10 +7,11 @@ message(STATUS "compile with lite ARM kernels") ...@@ -7,10 +7,11 @@ message(STATUS "compile with lite ARM kernels")
cc_library(fc_compute_arm SRCS fc_compute.cc DEPS ${lite_kernel_deps} math_arm) cc_library(fc_compute_arm SRCS fc_compute.cc DEPS ${lite_kernel_deps} math_arm)
cc_library(relu_compute_arm SRCS relu_compute.cc DEPS ${lite_kernel_deps}) cc_library(relu_compute_arm SRCS relu_compute.cc DEPS ${lite_kernel_deps})
cc_library(mul_compute_arm SRCS mul_compute.cc DEPS ${lite_kernel_deps} eigen3) cc_library(mul_compute_arm SRCS mul_compute.cc DEPS ${lite_kernel_deps} eigen3)
cc_library(scale_compute_arm SRCS scale_compute.cc DEPS ${lite_kernel_deps} eigen3) cc_library(scale_compute_arm SRCS scale_compute.cc DEPS ${lite_kernel_deps} math_arm)
cc_library(softmax_compute_arm SRCS softmax_compute.cc DEPS ${lite_kernel_deps} math_arm) cc_library(softmax_compute_arm SRCS softmax_compute.cc DEPS ${lite_kernel_deps} math_arm)
lite_cc_test(test_fc_compute_arm SRCS fc_compute_test.cc DEPS fc_compute_arm math_arm) lite_cc_test(test_fc_compute_arm SRCS fc_compute_test.cc DEPS fc_compute_arm math_arm)
lite_cc_test(test_scale_compute_arm SRCS scale_compute_test.cc DEPS scale_compute_arm)
lite_cc_test(test_softmax_compute_arm SRCS softmax_compute_test.cc DEPS softmax_compute_arm) lite_cc_test(test_softmax_compute_arm SRCS softmax_compute_test.cc DEPS softmax_compute_arm)
set(arm_kernels set(arm_kernels
......
...@@ -12,38 +12,27 @@ ...@@ -12,38 +12,27 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include <Eigen/Core> #include "paddle/fluid/lite/kernels/arm/scale_compute.h"
#include "paddle/fluid/lite/core/kernel.h" #include "paddle/fluid/lite/arm/math/funcs.h"
#include "paddle/fluid/lite/core/op_registry.h"
#include "paddle/fluid/lite/core/types.h"
namespace paddle { namespace paddle {
namespace lite { namespace lite {
namespace kernels { namespace kernels {
namespace arm { namespace arm {
template <typename T> void ScaleCompute::Run() {
void scale_compute(const T* x, T* out, int size, float scale, float bias,
bool bias_before) {
if (bias_before) bias *= scale;
for (int i = 0; i < size; i++) {
out[i] = x[i] * scale + bias;
}
}
class ScaleCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
public:
using param_t = operators::MulParam;
void Run() override {
auto& param = Param<operators::ScaleParam>(); auto& param = Param<operators::ScaleParam>();
scale_compute(param.x->data<float>(), param.output->mutable_data<float>(), const float* x_data = param.x->data<float>();
param.x->dims().production(), param.scale, param.bias, float* output_data = param.output->mutable_data<float>();
param.bias_after_scale); DDim x_dims = param.x->dims();
bool bias_after_scale = param.bias_after_scale;
float scale = param.scale;
float bias = param.bias;
if (!bias_after_scale) {
bias *= scale;
} }
lite::arm::math::scale(x_data, output_data, x_dims.production(), scale, bias);
virtual ~ScaleCompute() = default; }
};
} // namespace arm } // namespace arm
} // namespace kernels } // namespace kernels
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/lite/core/kernel.h"
#include "paddle/fluid/lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace arm {
class ScaleCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
public:
void Run() override;
virtual ~ScaleCompute() = default;
};
} // namespace arm
} // namespace kernels
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/kernels/arm/scale_compute.h"
#include <gtest/gtest.h>
#include <vector>
#include "paddle/fluid/lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace arm {
template <typename dtype>
void scale_compute_ref(const operators::ScaleParam& param) {
const dtype* x_data = param.x->mutable_data<const dtype>();
dtype* output_data = param.output->mutable_data<dtype>();
DDim x_dims = param.x->dims();
DDim output_dims = param.output->dims();
ASSERT_EQ(x_dims.data(), output_dims.data());
bool bias_after_scale = param.bias_after_scale;
float scale = param.scale;
float bias = param.bias;
if (!bias_after_scale) {
bias *= scale;
}
for (int i = 0; i < output_dims.production(); i++) {
output_data[i] = x_data[i] * scale + bias;
}
}
TEST(scale_arm, init) {
ScaleCompute scale;
ASSERT_EQ(scale.precision(), PRECISION(kFloat));
ASSERT_EQ(scale.target(), TARGET(kARM));
}
TEST(scale_arm, compute) {
ScaleCompute scale;
operators::ScaleParam param;
lite::Tensor x;
lite::Tensor output;
lite::Tensor output_ref;
for (auto n : {1, 3, 4, 11}) {
for (auto c : {1, 3, 11, 4}) {
for (auto h : {3, 1, 11, 4}) {
for (auto w : {1, 3, 4, 12}) {
for (auto bias_after_scale : {true, false}) {
for (auto s : {-100.25f, -1.0f, 0.13f, 3840.975f}) {
for (auto b : {-3075.495f, -15.f, 0.11234f, 128.15f}) {
x.Resize(DDim(std::vector<int64_t>({n, c, h, w})));
output.Resize(DDim(std::vector<int64_t>({n, c, h, w})));
output_ref.Resize(DDim(std::vector<int64_t>({n, c, h, w})));
auto* x_data = x.mutable_data<float>();
auto* output_data = output.mutable_data<float>();
auto* output_ref_data = output_ref.mutable_data<float>();
for (int i = 0; i < x.dims().production(); i++) {
x_data[i] = i;
}
param.x = &x;
param.output = &output;
param.bias_after_scale = bias_after_scale;
param.scale = s;
param.bias = b;
scale.SetParam(param);
scale.Run();
param.output = &output_ref;
scale_compute_ref<float>(param);
for (int i = 0; i < output.dims().production(); i++) {
EXPECT_NEAR(output_data[i], output_ref_data[i], 1e-5);
}
}
}
}
}
}
}
}
}
TEST(scale, retrive_op) {
auto scale =
KernelRegistry::Global().Create<TARGET(kARM), PRECISION(kFloat)>("scale");
ASSERT_FALSE(scale.empty());
ASSERT_TRUE(scale.front());
}
} // namespace arm
} // namespace kernels
} // namespace lite
} // namespace paddle
USE_LITE_KERNEL(scale, kARM, kFloat, kNCHW, def);
...@@ -24,15 +24,15 @@ void SoftmaxCompute::Run() { ...@@ -24,15 +24,15 @@ void SoftmaxCompute::Run() {
auto& param = Param<operators::SoftmaxParam>(); auto& param = Param<operators::SoftmaxParam>();
const float* din = param.x->data<float>(); const float* din = param.x->data<float>();
float* dout = param.output->mutable_data<float>(); float* dout = param.output->mutable_data<float>();
auto dim_x = param.x->dims(); auto x_dims = param.x->dims();
auto rank_x = dim_x.size(); auto x_rank = x_dims.size();
int axis = param.axis; int axis = param.axis;
if (axis < 0) { if (axis < 0) {
axis += rank_x; axis += x_rank;
} }
int outer_num = dim_x.Slice(0, axis).production(); int outer_num = x_dims.Slice(0, axis).production();
int inner_num = dim_x.Slice(axis + 1, rank_x).production(); int inner_num = x_dims.Slice(axis + 1, x_rank).production();
int axis_size = dim_x[axis]; int axis_size = x_dims[axis];
if (inner_num == 1) { if (inner_num == 1) {
if (axis_size >= 4) { if (axis_size >= 4) {
lite::arm::math::softmax_inner1_large_axis(din, dout, outer_num, lite::arm::math::softmax_inner1_large_axis(din, dout, outer_num,
...@@ -64,10 +64,6 @@ void SoftmaxCompute::Run() { ...@@ -64,10 +64,6 @@ void SoftmaxCompute::Run() {
} }
} }
TargetType SoftmaxCompute::target() const { return TARGET(kARM); }
PrecisionType SoftmaxCompute::precision() const { return PRECISION(kFloat); }
} // namespace arm } // namespace arm
} // namespace kernels } // namespace kernels
} // namespace lite } // namespace lite
......
...@@ -26,9 +26,6 @@ class SoftmaxCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> { ...@@ -26,9 +26,6 @@ class SoftmaxCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
public: public:
void Run() override; void Run() override;
TargetType target() const override;
PrecisionType precision() const override;
virtual ~SoftmaxCompute() = default; virtual ~SoftmaxCompute() = default;
}; };
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
#include "paddle/fluid/lite/kernels/arm/softmax_compute.h" #include "paddle/fluid/lite/kernels/arm/softmax_compute.h"
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <limits>
#include <vector> #include <vector>
#include "paddle/fluid/lite/core/op_registry.h" #include "paddle/fluid/lite/core/op_registry.h"
...@@ -23,19 +24,19 @@ namespace kernels { ...@@ -23,19 +24,19 @@ namespace kernels {
namespace arm { namespace arm {
template <typename dtype> template <typename dtype>
void softmat_compute_ref(const operators::SoftmaxParam& param) { void softmax_compute_ref(const operators::SoftmaxParam& param) {
const dtype* x_data = param.x->mutable_data<const dtype>(); const dtype* x_data = param.x->mutable_data<const dtype>();
dtype* output_data = param.output->mutable_data<dtype>(); dtype* output_data = param.output->mutable_data<dtype>();
DDim dim = param.x->dims(); DDim x_dims = param.x->dims();
ASSERT_EQ(dim.data(), param.output->dims().data()); ASSERT_EQ(x_dims.data(), param.output->dims().data());
auto rank = dim.size(); auto x_rank = x_dims.size();
int axis = param.axis; int axis = param.axis;
if (axis < 0) { if (axis < 0) {
axis += rank; axis += x_rank;
} }
int axis_size = dim[axis]; int axis_size = x_dims[axis];
int outer_num = dim.Slice(0, axis).production(); int outer_num = x_dims.Slice(0, axis).production();
int inner_num = dim.Slice(axis + 1, rank).production(); int inner_num = x_dims.Slice(axis + 1, x_rank).production();
int compute_size = outer_num * inner_num; int compute_size = outer_num * inner_num;
for (int i = 0; i < compute_size; i++) { for (int i = 0; i < compute_size; i++) {
int idx_inner = i % inner_num; int idx_inner = i % inner_num;
...@@ -100,7 +101,7 @@ TEST(softmax_arm, compute) { ...@@ -100,7 +101,7 @@ TEST(softmax_arm, compute) {
softmax.SetParam(param); softmax.SetParam(param);
softmax.Run(); softmax.Run();
param.output = &output_ref; param.output = &output_ref;
softmat_compute_ref<float>(param); softmax_compute_ref<float>(param);
for (int i = 0; i < output.dims().production(); i++) { for (int i = 0; i < output.dims().production(); i++) {
EXPECT_NEAR(output_data[i], output_ref_data[i], 1e-5); EXPECT_NEAR(output_data[i], output_ref_data[i], 1e-5);
} }
......
...@@ -2,10 +2,14 @@ message(STATUS "compile with lite host kernels") ...@@ -2,10 +2,14 @@ message(STATUS "compile with lite host kernels")
cc_library(feed_compute_host SRCS feed_compute.cc DEPS ${lite_kernel_deps}) cc_library(feed_compute_host SRCS feed_compute.cc DEPS ${lite_kernel_deps})
cc_library(fetch_compute_host SRCS fetch_compute.cc DEPS ${lite_kernel_deps}) cc_library(fetch_compute_host SRCS fetch_compute.cc DEPS ${lite_kernel_deps})
cc_library(reshape_compute_host SRCS reshape_compute.cc DEPS ${lite_kernel_deps} reshape_op_lite)
lite_cc_test(test_reshape_compute_host SRCS reshape_compute_test.cc DEPS reshape_compute_host)
set(host_kernels set(host_kernels
feed_compute_host feed_compute_host
fetch_compute_host fetch_compute_host
reshape_compute_host
) )
set(host_kernels "${host_kernels}" CACHE INTERNAL "host kernels") set(host_kernels "${host_kernels}" CACHE INTERNAL "host kernels")
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/kernels/host/reshape_compute.h"
#include <vector>
#include "paddle/fluid/lite/operators/reshape_op.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace host {
void ReshapeCompute::Run() {
auto& param = Param<operators::ReshapeParam>();
auto x = param.x;
auto actual_shape = param.actual_shape;
auto output = param.output;
bool inplace = param.inplace;
auto x_dims = x->dims();
auto output_dims = output->dims();
if (actual_shape) {
auto actual_shape_dims = actual_shape->dims();
auto* actual_shape_data = actual_shape->data<int>();
#ifdef LITE_WITH_CUDA
lite::Tensor cpu_actual_shape;
if (actual_shape->target() == TARGET(kCUDA)) {
cpu_actual_shape.CopyDataFrom(*actual_shape);
actual_shape_data = cpu_actual_shape.data<int>();
}
#endif
auto shape = std::vector<int>(
actual_shape_data, actual_shape_data + actual_shape_dims.production());
output_dims = lite::operators::ValidateShape(shape, x_dims);
output->Resize(output_dims);
}
if (inplace) {
output->ShareDataWith(*x);
} else {
output->CopyDataFrom(*x);
}
output->Resize(output_dims);
}
} // namespace host
} // namespace kernels
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(reshape, kHost, kAny, kAny,
paddle::lite::kernels::host::ReshapeCompute, def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kHost), PRECISION(kAny),
DATALAYOUT(kAny), -1)})
.BindInput("Shape", {LiteType::GetTensorTy(TARGET(kHost), PRECISION(kAny),
DATALAYOUT(kAny), -1)})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kHost), PRECISION(kAny),
DATALAYOUT(kAny), -1)})
.Finalize();
REGISTER_LITE_KERNEL(reshape2, kHost, kAny, kAny,
paddle::lite::kernels::host::ReshapeCompute, def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kHost), PRECISION(kAny),
DATALAYOUT(kAny), -1)})
.BindInput("Shape", {LiteType::GetTensorTy(TARGET(kHost), PRECISION(kAny),
DATALAYOUT(kAny), -1)})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kHost), PRECISION(kAny),
DATALAYOUT(kAny), -1)})
.BindOutput("XShape", {LiteType::GetTensorTy(TARGET(kHost), PRECISION(kAny),
DATALAYOUT(kAny), -1)})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include "paddle/fluid/lite/core/kernel.h"
#include "paddle/fluid/lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace host {
class ReshapeCompute
: public KernelLite<TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny)> {
public:
void Run() override;
virtual ~ReshapeCompute() = default;
};
} // namespace host
} // namespace kernels
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/kernels/host/reshape_compute.h"
#include <gtest/gtest.h>
#include <vector>
#include "paddle/fluid/lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace host {
TEST(reshape_host, init) {
ReshapeCompute reshape;
ASSERT_EQ(reshape.precision(), PRECISION(kAny));
ASSERT_EQ(reshape.target(), TARGET(kHost));
}
TEST(reshape_host, compute) {
ReshapeCompute reshape;
operators::ReshapeParam param;
Tensor x;
Tensor actual_shape;
Tensor output;
x.Resize(DDim(std::vector<int64_t>({1, 2, 4, 6})));
actual_shape.Resize(DDim(std::vector<int64_t>({2})));
auto* x_data = x.mutable_data<float>();
auto* actual_shape_data = actual_shape.mutable_data<int>();
for (int i = 0; i < x.dims().production(); i++) {
x_data[i] = i;
}
actual_shape_data[0] = 6;
actual_shape_data[1] = 8;
param.x = &x;
param.shape = {-1, 0, 3, 2, 1};
param.output = &output;
param.actual_shape = &actual_shape;
param.inplace = false;
reshape.SetParam(param);
reshape.Run();
// check output dims
CHECK_EQ(actual_shape.dims().production(), output.dims().size());
for (int i = 0; i < output.dims().size(); i++) {
CHECK_EQ(output.dims()[i], actual_shape_data[i]);
}
// check output data
auto* output_data = output.mutable_data<float>();
CHECK_NE(output_data, x_data);
for (int i = 0; i < output.dims().production(); i++) {
EXPECT_NEAR(output_data[i], x_data[i], 1e-6);
}
// check output data if inplace = true;
param.inplace = true;
reshape.SetParam(param);
reshape.Run();
output_data = output.mutable_data<float>();
CHECK_EQ(output_data, x_data);
}
TEST(reshape, retrive_op) {
auto reshape =
KernelRegistry::Global()
.Create<TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny)>("reshape");
ASSERT_FALSE(reshape.empty());
ASSERT_TRUE(reshape.front());
}
TEST(reshape2, retrive_op) {
auto reshape2 =
KernelRegistry::Global()
.Create<TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny)>("reshape2");
ASSERT_FALSE(reshape2.empty());
ASSERT_TRUE(reshape2.front());
}
} // namespace host
} // namespace kernels
} // namespace lite
} // namespace paddle
USE_LITE_KERNEL(reshape, kHost, kAny, kAny, def);
USE_LITE_KERNEL(reshape2, kHost, kAny, kAny, def);
...@@ -17,3 +17,5 @@ ...@@ -17,3 +17,5 @@
USE_LITE_KERNEL(feed, kHost, kAny, kAny, def); USE_LITE_KERNEL(feed, kHost, kAny, kAny, def);
USE_LITE_KERNEL(fetch, kHost, kAny, kAny, def); USE_LITE_KERNEL(fetch, kHost, kAny, kAny, def);
USE_LITE_KERNEL(reshape, kHost, kAny, kAny, def);
USE_LITE_KERNEL(reshape2, kHost, kAny, kAny, def);
...@@ -38,6 +38,29 @@ void OpDesc::SetAttr<std::string>(const std::string &name, ...@@ -38,6 +38,29 @@ void OpDesc::SetAttr<std::string>(const std::string &name,
it->set_s(v.c_str()); it->set_s(v.c_str());
} }
template <>
void OpDesc::SetAttr<std::vector<int>>(const std::string &name,
const std::vector<int> &v) {
auto &xs = *desc_.mutable_attrs();
auto it = std::find_if(
xs.begin(), xs.end(),
[&](const framework::proto::OpDesc_Attr &x) { return x.name() == name; });
if (it == xs.end()) {
auto *attr = xs.Add();
attr->set_name(name);
it = std::find_if(xs.begin(), xs.end(),
[&](const framework::proto::OpDesc_Attr &x) {
return x.name() == name;
});
}
it->set_type(framework::proto::INTS);
it->clear_ints();
for (auto &i : v) {
it->add_ints(i);
}
}
} // namespace pb } // namespace pb
} // namespace lite } // namespace lite
} // namespace paddle } // namespace paddle
...@@ -33,7 +33,8 @@ namespace paddle { ...@@ -33,7 +33,8 @@ namespace paddle {
namespace lite { namespace lite {
namespace pb { namespace pb {
using Attribute = variant<int, float, bool, std::vector<std::string>>; using Attribute =
variant<int, float, bool, std::vector<std::string>, std::vector<int>>;
using VariableNameMap = std::map<std::string, std::vector<std::string>>; using VariableNameMap = std::map<std::string, std::vector<std::string>>;
/* /*
...@@ -152,7 +153,6 @@ class OpDesc { ...@@ -152,7 +153,6 @@ class OpDesc {
Attribute res; Attribute res;
CHECK(it != xs.end()); CHECK(it != xs.end());
switch (it->type()) { switch (it->type()) {
case framework::proto::INT: case framework::proto::INT:
res.set<int>(it->i()); res.set<int>(it->i());
...@@ -166,6 +166,13 @@ class OpDesc { ...@@ -166,6 +166,13 @@ class OpDesc {
case framework::proto::BOOLEAN: case framework::proto::BOOLEAN:
res.set<bool>(it->b()); res.set<bool>(it->b());
break; break;
case framework::proto::INTS: {
std::vector<int> values;
const auto &ys = it->ints();
std::transform(ys.begin(), ys.end(), std::back_inserter(values),
[](const int &x) { return x; });
res.set<std::vector<int>>(values);
} break;
default: default:
LOG(FATAL) << "unsupported attr type"; LOG(FATAL) << "unsupported attr type";
...@@ -231,6 +238,10 @@ template <> ...@@ -231,6 +238,10 @@ template <>
void OpDesc::SetAttr<std::string>(const std::string &name, void OpDesc::SetAttr<std::string>(const std::string &name,
const std::string &v); const std::string &v);
template <>
void OpDesc::SetAttr<std::vector<int>>(const std::string &name,
const std::vector<int> &v);
} // namespace pb } // namespace pb
} // namespace lite } // namespace lite
} // namespace paddle } // namespace paddle
...@@ -5,6 +5,7 @@ cc_library(relu_op_lite SRCS relu_op.cc DEPS ${op_DEPS}) ...@@ -5,6 +5,7 @@ cc_library(relu_op_lite SRCS relu_op.cc DEPS ${op_DEPS})
cc_library(mul_op_lite SRCS mul_op.cc DEPS ${op_DEPS}) cc_library(mul_op_lite SRCS mul_op.cc DEPS ${op_DEPS})
cc_library(scale_op_lite SRCS scale_op.cc DEPS ${op_DEPS}) cc_library(scale_op_lite SRCS scale_op.cc DEPS ${op_DEPS})
cc_library(softmax_op_lite SRCS softmax_op.cc DEPS ${op_DEPS}) cc_library(softmax_op_lite SRCS softmax_op.cc DEPS ${op_DEPS})
cc_library(reshape_op_lite SRCS reshape_op.cc DEPS ${op_DEPS} )
cc_library(feed_op_lite SRCS feed_op.cc DEPS ${op_DEPS}) cc_library(feed_op_lite SRCS feed_op.cc DEPS ${op_DEPS})
cc_library(fetch_op_lite SRCS fetch_op.cc DEPS ${op_DEPS}) cc_library(fetch_op_lite SRCS fetch_op.cc DEPS ${op_DEPS})
cc_library(io_copy_op_lite SRCS io_copy_op.cc DEPS ${op_DEPS}) cc_library(io_copy_op_lite SRCS io_copy_op.cc DEPS ${op_DEPS})
...@@ -22,6 +23,7 @@ set(ops_lite ...@@ -22,6 +23,7 @@ set(ops_lite
mul_op_lite mul_op_lite
scale_op_lite scale_op_lite
softmax_op_lite softmax_op_lite
reshape_op_lite
feed_op_lite feed_op_lite
fetch_op_lite fetch_op_lite
io_copy_op_lite io_copy_op_lite
...@@ -36,4 +38,6 @@ lite_cc_test(test_fc_op_lite SRCS fc_op_test.cc ...@@ -36,4 +38,6 @@ lite_cc_test(test_fc_op_lite SRCS fc_op_test.cc
DEPS fc_op_lite memory_lite DEPS fc_op_lite memory_lite
X86_DEPS fc_compute_x86 X86_DEPS fc_compute_x86
ARM_DEPS fc_compute_arm) ARM_DEPS fc_compute_arm)
lite_cc_test(test_scale_op_lite SRCS scale_op_test.cc DEPS scale_op_lite memory_lite)
lite_cc_test(test_softmax_op_lite SRCS softmax_op_test.cc DEPS softmax_op_lite memory_lite) lite_cc_test(test_softmax_op_lite SRCS softmax_op_test.cc DEPS softmax_op_lite memory_lite)
lite_cc_test(test_reshape_op_lite SRCS reshape_op_test.cc DEPS reshape_op_lite memory_lite)
...@@ -102,6 +102,17 @@ struct SoftmaxParam { ...@@ -102,6 +102,17 @@ struct SoftmaxParam {
int axis{-1}; int axis{-1};
}; };
// For Reshape and Reshape2 Op
struct ReshapeParam {
const lite::Tensor* x{};
const lite::Tensor* actual_shape{nullptr};
lite::Tensor* output{};
lite::Tensor* xshape{};
std::vector<int> shape{};
bool inplace{false};
};
// For Convolution op // For Convolution op
struct ConvParam { struct ConvParam {
lite::Tensor* x{}; lite::Tensor* x{};
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/operators/reshape_op.h"
#include "paddle/fluid/lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace operators {
bool ReshapeOp::CheckShape() const {
CHECK_OR_FALSE(param_.x);
CHECK_OR_FALSE(param_.output);
CHECK_OR_FALSE(!param_.shape.empty());
return true;
}
bool ReshapeOp::InferShape() const {
auto x_dims = param_.x->dims();
auto output_dims = ValidateShape(param_.shape, x_dims);
param_.output->Resize(output_dims);
return true;
}
bool ReshapeOp::AttachImpl(const OpDesc &opdesc, lite::Scope *scope) {
auto x_var = scope->FindVar(opdesc.Input("X").front());
auto output_var = scope->FindVar(opdesc.Output("Out").front());
CHECK(x_var);
CHECK(output_var);
param_.x = const_cast<lite::Tensor *>(&(x_var->Get<lite::Tensor>()));
param_.output = output_var->GetMutable<lite::Tensor>();
std::vector<std::string> input_arg_names = opdesc.InputArgumentNames();
if (std::find(input_arg_names.begin(), input_arg_names.end(), "Shape") !=
input_arg_names.end()) {
auto actual_shape_var = scope->FindVar(opdesc.Input("Shape").front());
if (actual_shape_var != nullptr) {
param_.actual_shape =
const_cast<lite::Tensor *>(&(actual_shape_var->Get<lite::Tensor>()));
}
}
param_.shape = GetAttr<std::vector<int>>(opdesc.GetAttr("shape"));
if (opdesc.HasAttr("inplace")) {
param_.inplace = GetAttr<bool>(opdesc.GetAttr("inplace"));
}
CHECK(param_.x) << "Input(X) of ReshapeOp should not be null.";
CHECK(param_.output) << "Output(Out) of ReshapeOp should not be null.";
CHECK(!param_.shape.empty())
<< "The shape information must be set by Attr(shape).";
return true;
}
bool Reshape2Op::CheckShape() const {
ReshapeOp::CheckShape();
CHECK_OR_FALSE(param_.xshape);
return true;
}
bool Reshape2Op::InferShape() const {
ReshapeOp::InferShape();
auto x_dims = param_.x->dims();
std::vector<DDim::value_type> xshape_dims(x_dims.size() + 1, 0);
for (int i = 0; i < x_dims.size(); i++) {
xshape_dims[i + 1] = x_dims[i];
}
param_.xshape->Resize(DDim(xshape_dims));
return true;
}
bool Reshape2Op::AttachImpl(const OpDesc &opdesc, lite::Scope *scope) {
ReshapeOp::AttachImpl(opdesc, scope);
auto xshape_var = scope->FindVar(opdesc.Output("XShape").front());
CHECK(xshape_var);
param_.xshape = xshape_var->GetMutable<lite::Tensor>();
CHECK(param_.xshape) << "Output(XShape) of ReshapeOp should not be null.";
return true;
}
DDim ValidateShape(const std::vector<int> &shape, const DDim &input_dims) {
const DDim::value_type input_size = input_dims.production();
auto input_shape = input_dims.Vectorize();
bool all_positive = std::all_of(input_shape.cbegin(), input_shape.cend(),
[](DDim::value_type i) { return i > 0; });
// only one dimension can be set to -1, whose size will be automatically
// infered.
const int unk_dim_val = -1;
const int copy_dim_val = 0;
std::vector<DDim::value_type> output_shape(shape.size(), 0);
DDim::value_type capacity = 1;
int unk_dim_idx = -1;
for (size_t i = 0; i < shape.size(); ++i) {
if (shape[i] == unk_dim_val) {
CHECK_EQ(unk_dim_idx, -1)
<< "Only one input dimension of Attr(shape) can be unknown.";
unk_dim_idx = i;
} else if (shape[i] == copy_dim_val) {
CHECK_LT(static_cast<int>(i), input_shape.size())
<< "The index of dimension to copy from input shape must be less "
"than the size of input shape.";
} else {
CHECK_GT(shape[i], 0) << "Each input dimension of Attr(shape) must not "
"be negtive except one unknown dimension.";
}
capacity *=
(shape[i] ? static_cast<DDim::value_type>(shape[i]) : input_shape[i]);
output_shape[i] =
(shape[i] ? static_cast<DDim::value_type>(shape[i]) : input_shape[i]);
}
if (unk_dim_idx != -1) {
if (all_positive) {
// input_size < 0 and is un-determinate in compile time, skip the check,
// for example, input_dims = [-1, 8, 1, 1], shape = [-1, 3, 8],
// capacity = -24, input_size = -8, output_shape[0] = 0
// the following check will fail.
output_shape[unk_dim_idx] = -input_size / capacity;
CHECK_EQ(output_shape[unk_dim_idx] * capacity, -input_size)
<< "Invalid shape is given.";
} else {
output_shape[unk_dim_idx] = -1;
}
} else {
CHECK_EQ(capacity, input_size) << "Invalid shape is given.";
}
return DDim(output_shape);
}
} // namespace operators
} // namespace lite
} // namespace paddle
REGISTER_LITE_OP(reshape, paddle::lite::operators::ReshapeOp);
REGISTER_LITE_OP(reshape2, paddle::lite::operators::Reshape2Op);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "paddle/fluid/lite/core/op_lite.h"
#include "paddle/fluid/lite/core/scope.h"
#include "paddle/fluid/lite/utils/all.h"
namespace paddle {
namespace lite {
namespace operators {
class ReshapeOp : public OpLite {
public:
ReshapeOp() {}
explicit ReshapeOp(const std::string &op_type) : OpLite(op_type) {}
bool CheckShape() const override;
bool InferShape() const override;
bool AttachImpl(const OpDesc &opdesc, lite::Scope *scope) override;
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
std::string DebugString() const override { return "reshape"; }
protected:
mutable ReshapeParam param_;
};
class Reshape2Op : public ReshapeOp {
public:
Reshape2Op() : ReshapeOp() {}
explicit Reshape2Op(const std::string &op_type) : ReshapeOp(op_type) {}
bool CheckShape() const override;
bool InferShape() const override;
bool AttachImpl(const OpDesc &opdesc, lite::Scope *scope) override;
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
std::string DebugString() const override { return "reshape2"; }
};
DDim ValidateShape(const std::vector<int> &shape, const DDim &input_dims);
} // namespace operators
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/operators/reshape_op.h"
#include <gtest/gtest.h>
#include <map>
#include <vector>
#include "paddle/fluid/lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace operators {
TEST(reshape_op_lite, test) {
// prepare variables
Scope scope;
auto* x = scope.Var("x")->GetMutable<Tensor>();
auto* actual_shape = scope.Var("actual_shape")->GetMutable<Tensor>();
auto* output = scope.Var("output")->GetMutable<Tensor>();
std::map<std::vector<int>, std::vector<int64_t>> shapes = {
{{-1, 0, 3, 2, 1}, {2, 4, 3, 2, 1}},
{{0, -1, 3, 2, 1}, {2, 4, 3, 2, 1}},
{{-1, 48}, {1, 48}},
{{48, -1}, {48, 1}},
{{0, 24}, {2, 24}},
{{12, 0}, {12, 4}},
};
x->Resize(DDim(std::vector<int64_t>({2, 4, 6})));
actual_shape->Resize(DDim(std::vector<int64_t>({2})));
auto* actual_shape_data = actual_shape->mutable_data<int>();
actual_shape_data[0] = 6;
actual_shape_data[1] = 8;
for (auto& shape : shapes) {
for (auto& has_actual_shape : {true, false}) {
for (auto& inplace : {true, false}) {
// prepare op desc
lite::OpDesc desc;
desc.SetType("reshape");
desc.SetInput("X", {"x"});
if (has_actual_shape) {
desc.SetInput("Shape", {"actual_shape"});
}
desc.SetOutput("Out", {"output"});
desc.SetAttr("shape", shape.first);
desc.SetAttr("inplace", inplace);
ReshapeOp reshape("reshape");
reshape.SetValidPlaces(
{Place{TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny)}});
reshape.Attach(desc, &scope);
reshape.CheckShape();
reshape.InferShape();
// check output dims
auto output_dims = output->dims();
CHECK_EQ(output_dims.size(), shape.second.size());
for (int i = 0; i < output_dims.size(); i++) {
CHECK_EQ(output_dims[i], shape.second[i]);
}
}
}
}
}
TEST(reshape2_op_lite, test) {
// prepare variables
Scope scope;
auto* x = scope.Var("x")->GetMutable<Tensor>();
auto* actual_shape = scope.Var("actual_shape")->GetMutable<Tensor>();
auto* output = scope.Var("output")->GetMutable<Tensor>();
auto* xshape = scope.Var("xshape")->GetMutable<Tensor>();
std::map<std::vector<int>, std::vector<int64_t>> shapes = {
{{-1, 0, 3, 2, 1}, {2, 4, 3, 2, 1}},
{{0, -1, 3, 2, 1}, {2, 4, 3, 2, 1}},
{{-1, 48}, {1, 48}},
{{48, -1}, {48, 1}},
{{0, 24}, {2, 24}},
{{12, 0}, {12, 4}},
};
x->Resize(DDim(std::vector<int64_t>({2, 4, 6})));
actual_shape->Resize(DDim(std::vector<int64_t>({2})));
auto* actual_shape_data = actual_shape->mutable_data<int>();
actual_shape_data[0] = 6;
actual_shape_data[1] = 8;
for (auto& shape : shapes) {
for (auto& has_actual_shape : {true, false}) {
for (auto& inplace : {true, false}) {
// prepare op desc
lite::OpDesc desc;
desc.SetType("reshape");
desc.SetInput("X", {"x"});
if (has_actual_shape) {
desc.SetInput("Shape", {"actual_shape"});
}
desc.SetOutput("Out", {"output"});
desc.SetOutput("XShape", {"xshape"});
desc.SetAttr("shape", shape.first);
desc.SetAttr("inplace", inplace);
Reshape2Op reshape2("reshape2");
reshape2.SetValidPlaces(
{Place{TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny)}});
reshape2.Attach(desc, &scope);
reshape2.CheckShape();
reshape2.InferShape();
// check output dims
auto output_dims = output->dims();
CHECK_EQ(output_dims.size(), shape.second.size());
for (int i = 0; i < output_dims.size(); i++) {
CHECK_EQ(output_dims[i], shape.second[i]);
}
// check xshape dims
auto x_dims = x->dims();
auto xshape_dims = xshape->dims();
CHECK_EQ(xshape_dims.size(), x_dims.size() + 1);
CHECK_EQ(xshape_dims[0], 0);
for (int i = 0; i < x_dims.size(); i++) {
CHECK_EQ(xshape_dims[i + 1], x_dims[i]);
}
}
}
}
}
} // namespace operators
} // namespace lite
} // namespace paddle
...@@ -12,58 +12,35 @@ ...@@ -12,58 +12,35 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include <string> #include "paddle/fluid/lite/operators/scale_op.h"
#include <vector>
#include "paddle/fluid/lite/core/kernel.h"
#include "paddle/fluid/lite/core/op_lite.h"
#include "paddle/fluid/lite/core/op_registry.h" #include "paddle/fluid/lite/core/op_registry.h"
#include "paddle/fluid/lite/core/scope.h"
#include "paddle/fluid/lite/operators/op_params.h"
#include "paddle/fluid/lite/utils/all.h"
namespace paddle { namespace paddle {
namespace lite { namespace lite {
namespace operators { namespace operators {
class ScaleOp : public OpLite { bool ScaleOp::CheckShape() const {
public:
ScaleOp() {}
explicit ScaleOp(const std::string &type) : OpLite(type) {}
bool CheckShape() const override {
CHECK_OR_FALSE(param_.x); CHECK_OR_FALSE(param_.x);
CHECK_OR_FALSE(param_.output); CHECK_OR_FALSE(param_.output);
return true; return true;
} }
bool InferShape() const override { bool ScaleOp::InferShape() const {
param_.output->Resize(param_.x->dims()); param_.output->Resize(param_.x->dims());
return true; return true;
} }
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
// TODO(Superjomn) replace framework::OpDesc with a lite one. bool ScaleOp::AttachImpl(const OpDesc &op_desc, lite::Scope *scope) {
bool AttachImpl(const OpDesc &op_desc, lite::Scope *scope) override {
auto x = op_desc.Input("X").front(); auto x = op_desc.Input("X").front();
auto out = op_desc.Output("Out").front(); auto output = op_desc.Output("Out").front();
param_.x = scope->FindVar(x)->GetMutable<Tensor>(); param_.x = scope->FindVar(x)->GetMutable<Tensor>();
CHECK(scope->FindVar(out)); param_.output = scope->FindVar(output)->GetMutable<Tensor>();
param_.output = scope->FindVar(out)->GetMutable<Tensor>();
param_.scale = GetAttr<float>(op_desc.GetAttr("scale")); param_.scale = GetAttr<float>(op_desc.GetAttr("scale"));
param_.bias = GetAttr<float>(op_desc.GetAttr("bias")); param_.bias = GetAttr<float>(op_desc.GetAttr("bias"));
param_.bias_after_scale = param_.bias_after_scale = GetAttr<bool>(op_desc.GetAttr("bias_after_scale"));
GetAttr<bool>(op_desc.GetAttr("bias_after_scale")); CHECK(param_.x);
CHECK(param_.output);
return true; return true;
} }
std::string DebugString() const override { return op_type_; }
private:
mutable ScaleParam param_;
};
} // namespace operators } // namespace operators
} // namespace lite } // namespace lite
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "paddle/fluid/lite/core/op_lite.h"
#include "paddle/fluid/lite/core/scope.h"
#include "paddle/fluid/lite/utils/all.h"
namespace paddle {
namespace lite {
namespace operators {
class ScaleOp : public OpLite {
public:
ScaleOp() {}
explicit ScaleOp(const std::string &op_type) : OpLite(op_type) {}
bool CheckShape() const override;
bool InferShape() const override;
bool AttachImpl(const OpDesc &opdesc, lite::Scope *scope) override;
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
std::string DebugString() const override { return "scale"; }
private:
mutable ScaleParam param_;
};
} // namespace operators
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/operators/scale_op.h"
#include <gtest/gtest.h>
#include "paddle/fluid/lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace operators {
TEST(scale_op_lite, test) {
// prepare variables
Scope scope;
auto* x = scope.Var("x")->GetMutable<Tensor>();
auto* output = scope.Var("output")->GetMutable<Tensor>();
x->Resize(DDim(std::vector<int64_t>({10, 20})));
output->Resize(DDim(std::vector<int64_t>{1, 1}));
// prepare op desc
lite::OpDesc desc;
desc.SetType("scale");
desc.SetInput("X", {"x"});
desc.SetOutput("Out", {"output"});
desc.SetAttr("bias_after_scale", false);
desc.SetAttr("scale", 0.5f);
desc.SetAttr("bias", 0.125f);
ScaleOp scale("scale");
scale.SetValidPlaces({Place{TARGET(kHost), PRECISION(kFloat)}});
scale.Attach(desc, &scope);
scale.CheckShape();
scale.InferShape();
// check output dims
auto x_dims = x->dims();
auto output_dims = output->dims();
CHECK_EQ(output_dims.size(), x_dims.size());
for (int i = 0; i < output_dims.size(); i++) {
CHECK_EQ(output_dims[i], x_dims[i]);
}
}
} // namespace operators
} // namespace lite
} // namespace paddle
...@@ -22,9 +22,9 @@ namespace operators { ...@@ -22,9 +22,9 @@ namespace operators {
bool SoftmaxOp::CheckShape() const { bool SoftmaxOp::CheckShape() const {
CHECK_OR_FALSE(param_.x); CHECK_OR_FALSE(param_.x);
CHECK_OR_FALSE(param_.output); CHECK_OR_FALSE(param_.output);
auto dim_x = param_.x->dims(); auto x_dims = param_.x->dims();
auto rank_x = dim_x.size(); auto x_rank = x_dims.size();
CHECK_OR_FALSE(param_.axis >= -rank_x && param_.axis < rank_x); CHECK_OR_FALSE(param_.axis >= -x_rank && param_.axis < x_rank);
return true; return true;
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册