未验证 提交 afefe9cf 编写于 作者: Z zhupengyang 提交者: GitHub

assign support tensor_array (#3436)

上级 2168ff38
...@@ -56,7 +56,6 @@ add_kernel(negative_compute_arm ARM extra SRCS negative_compute.cc DEPS ${lite_k ...@@ -56,7 +56,6 @@ add_kernel(negative_compute_arm ARM extra SRCS negative_compute.cc DEPS ${lite_k
add_kernel(crop_compute_arm ARM extra SRCS crop_compute.cc DEPS ${lite_kernel_deps} math_arm) add_kernel(crop_compute_arm ARM extra SRCS crop_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(power_compute_arm ARM extra SRCS power_compute.cc DEPS ${lite_kernel_deps} math_arm) add_kernel(power_compute_arm ARM extra SRCS power_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(norm_compute_arm ARM extra SRCS norm_compute.cc DEPS ${lite_kernel_deps} math_arm) add_kernel(norm_compute_arm ARM extra SRCS norm_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(assign_compute_arm ARM extra SRCS assign_compute.cc DEPS ${lite_kernel_deps} math_arm)
## 3. extra kernels ## 3. extra kernels
add_kernel(lrn_compute_arm ARM extra SRCS lrn_compute.cc DEPS ${lite_kernel_deps} math_arm) add_kernel(lrn_compute_arm ARM extra SRCS lrn_compute.cc DEPS ${lite_kernel_deps} math_arm)
......
...@@ -12,3 +12,4 @@ add_kernel(logical_compute_host Host extra SRCS logical_compute.cc DEPS ${lite_k ...@@ -12,3 +12,4 @@ add_kernel(logical_compute_host Host extra SRCS logical_compute.cc DEPS ${lite_k
add_kernel(ctc_align_compute_host Host extra SRCS ctc_align_compute.cc DEPS ${lite_kernel_deps}) add_kernel(ctc_align_compute_host Host extra SRCS ctc_align_compute.cc DEPS ${lite_kernel_deps})
add_kernel(write_to_array_compute_host Host extra SRCS write_to_array_compute.cc DEPS ${lite_kernel_deps}) add_kernel(write_to_array_compute_host Host extra SRCS write_to_array_compute.cc DEPS ${lite_kernel_deps})
add_kernel(read_from_array_compute_host Host extra SRCS read_from_array_compute.cc DEPS ${lite_kernel_deps}) add_kernel(read_from_array_compute_host Host extra SRCS read_from_array_compute.cc DEPS ${lite_kernel_deps})
add_kernel(assign_compute_host Host extra SRCS assign_compute.cc DEPS ${lite_kernel_deps})
...@@ -12,29 +12,42 @@ ...@@ -12,29 +12,42 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "lite/kernels/arm/assign_compute.h" #include "lite/kernels/host/assign_compute.h"
#include <vector>
#include "lite/backends/arm/math/funcs.h"
#include "lite/core/op_registry.h"
#include "lite/core/type_system.h"
namespace paddle { namespace paddle {
namespace lite { namespace lite {
namespace kernels { namespace kernels {
namespace arm { namespace host {
void AssignCompute::Run() { void AssignCompute::Run() {
auto& param = Param<param_t>(); auto& param = Param<param_t>();
param.Out->CopyDataFrom(*param.X); if (param.X != nullptr) {
param.Out->CopyDataFrom(*param.X);
} else if (param.X_array != nullptr) {
auto x_array = param.X_array;
auto out_array = param.Out_array;
out_array->resize(x_array->size());
for (size_t i = 0; i < x_array->size(); i++) {
out_array->at(i).CopyDataFrom(x_array->at(i));
}
} else {
LOG(FATAL) << "x or x_array of assign must be set.";
}
} }
} // namespace arm } // namespace host
} // namespace kernels } // namespace kernels
} // namespace lite } // namespace lite
} // namespace paddle } // namespace paddle
// Registers AssignCompute as the kernel for the `assign` op under the alias
// `def`, binding the "X" input and the "Out" output as tensor types that
// accept any precision.
// NOTE(review): only tensor types are bound here, while AssignCompute::Run
// also handles the tensor-array form (X_array/Out_array) -- confirm whether a
// tensor-list binding is needed for that path.
REGISTER_LITE_KERNEL( REGISTER_LITE_KERNEL(
assign, kARM, kAny, kNCHW, paddle::lite::kernels::arm::AssignCompute, def) assign, kHost, kAny, kAny, paddle::lite::kernels::host::AssignCompute, def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kAny))}) .BindInput("X",
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kAny))}) {LiteType::GetTensorTy(TARGET(kHost),
PRECISION(kAny),
DATALAYOUT(kAny))})
.BindOutput("Out",
{LiteType::GetTensorTy(TARGET(kHost),
PRECISION(kAny),
DATALAYOUT(kAny))})
.Finalize(); .Finalize();
...@@ -15,14 +15,15 @@ ...@@ -15,14 +15,15 @@
#pragma once #pragma once
#include <algorithm> #include <algorithm>
#include "lite/core/kernel.h" #include "lite/core/kernel.h"
#include "lite/operators/assign_op.h" #include "lite/core/op_registry.h"
namespace paddle { namespace paddle {
namespace lite { namespace lite {
namespace kernels { namespace kernels {
namespace arm { namespace host {
class AssignCompute : public KernelLite<TARGET(kARM), PRECISION(kAny)> { class AssignCompute
: public KernelLite<TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny)> {
public: public:
using param_t = operators::AssignParam; using param_t = operators::AssignParam;
...@@ -31,7 +32,7 @@ class AssignCompute : public KernelLite<TARGET(kARM), PRECISION(kAny)> { ...@@ -31,7 +32,7 @@ class AssignCompute : public KernelLite<TARGET(kARM), PRECISION(kAny)> {
virtual ~AssignCompute() = default; virtual ~AssignCompute() = default;
}; };
} // namespace arm } // namespace host
} // namespace kernels } // namespace kernels
} // namespace lite } // namespace lite
} // namespace paddle } // namespace paddle
...@@ -27,20 +27,33 @@ bool AssignOpLite::CheckShape() const { ...@@ -27,20 +27,33 @@ bool AssignOpLite::CheckShape() const {
} }
bool AssignOpLite::InferShapeImpl() const { bool AssignOpLite::InferShapeImpl() const {
lite::DDim input_dims; if (param_.X != nullptr) {
input_dims = param_.X->dims(); param_.Out->Resize(param_.X->dims());
param_.Out->Resize(lite::DDim(input_dims)); } else if (param_.X_array != nullptr) {
param_.Out_array->resize(param_.Out_array->size());
} else {
LOG(FATAL) << "x or x_array must be set.";
}
return true; return true;
} }
// TODO(Superjomn) replace framework::OpDesc with a lite one. // TODO(Superjomn) replace framework::OpDesc with a lite one.
// Binds the op's parameters to the variables named in `op_desc`.
//
// The first "X" input and first "Out" output names are looked up in `scope`.
// Depending on the runtime type held by the X variable, either the tensor
// pair (X/Out) or the tensor-array pair (X_array/Out_array) of the param is
// filled in. Any other X type is reported through LOG(FATAL).
//
// @param op_desc  op description providing the "X" and "Out" variable names.
// @param scope    scope in which those variables are resolved.
// @return always true.
bool AssignOpLite::AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) {
  auto x_name = op_desc.Input("X").front();
  auto out_name = op_desc.Output("Out").front();

  // Both variables are dereferenced below, so fail with a clear message if
  // either one is missing from the scope.
  auto x_var = scope->FindVar(x_name);
  CHECK(x_var) << "input var of assign is not found in scope: " << x_name;
  auto out_var = scope->FindVar(out_name);
  CHECK(out_var) << "output var of assign is not found in scope: " << out_name;

  if (x_var->IsType<Tensor>()) {
    param_.X = scope->FindTensor(x_name);
    param_.Out = scope->FindMutableTensor(out_name);
  } else if (x_var->IsType<std::vector<Tensor>>()) {
    param_.X_array = x_var->GetMutable<std::vector<Tensor>>();
    param_.Out_array = out_var->GetMutable<std::vector<Tensor>>();
  } else {
    LOG(FATAL) << "X type for assign op is unsupported. Expected type is "
                  "tensor or tensor_array.";
  }
  return true;
}
......
...@@ -1279,8 +1279,13 @@ struct GatherParam : ParamBase { ...@@ -1279,8 +1279,13 @@ struct GatherParam : ParamBase {
/// ----------------------- assign operators ----------------------- /// ----------------------- assign operators -----------------------
struct AssignParam : ParamBase { struct AssignParam : ParamBase {
const lite::Tensor* X{}; // for tensor
lite::Tensor* Out{}; const lite::Tensor* X{nullptr};
lite::Tensor* Out{nullptr};
// for tensor_array
const std::vector<lite::Tensor>* X_array{nullptr};
std::vector<lite::Tensor>* Out_array{nullptr};
}; };
/// ----------------------- roi_align operators ----------------------- /// ----------------------- roi_align operators -----------------------
......
...@@ -69,7 +69,7 @@ void TestAssign(const Place& place) { ...@@ -69,7 +69,7 @@ void TestAssign(const Place& place) {
TEST(Assign, precision) { TEST(Assign, precision) {
Place place; Place place;
#ifdef LITE_WITH_ARM #ifdef LITE_WITH_ARM
place = TARGET(kARM); place = TARGET(kHost);
#else #else
return; return;
#endif #endif
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册