Unverified commit 3a04e11d, authored by zhupengyang, committed by GitHub

move read_from_array and write_to_array to host (#3428)

Parent b445941f
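Both kernels are plain index-and-copy operations with no ARM-specific math, which is why they can be registered on the Host target with kAny precision and layout. A minimal consolidated sketch of the resulting host kernel, condensed from the read_from_array hunks below (KernelLite, ReadFromArrayParam, and REGISTER_LITE_KERNEL are Paddle-Lite's own types and macros; the real file defines Run() out of class and spells out the full X/I/Out bindings):

```cpp
// Condensed sketch only: the actual change splits the declaration into a
// header and defines Run() in the .cc, as the hunks below show.
#include "lite/core/kernel.h"
#include "lite/core/op_registry.h"

namespace paddle {
namespace lite {
namespace kernels {
namespace host {

class ReadFromArrayCompute
    : public KernelLite<TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny)> {
 public:
  void Run() override {
    auto& param = this->Param<operators::ReadFromArrayParam>();
    // "I" carries a single int64 index into the tensor array "X".
    CHECK_EQ(param.I->numel(), 1) << "I should have only one element";
    int id = param.I->data<int64_t>()[0];
    // Plain tensor copy: no ARM math funcs and no device context, which is
    // what lets the kernel move from the ARM target to Host.
    param.Out->CopyDataFrom((*param.X)[id]);
  }
};

}  // namespace host
}  // namespace kernels
}  // namespace lite
}  // namespace paddle
```

write_to_array follows the same pattern, with the tensor and tensor-list roles of X and Out swapped.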
...
@@ -92,8 +92,6 @@ add_kernel(sequence_softmax_compute_arm ARM extra SRCS sequence_softmax_compute.
 add_kernel(while_compute_arm ARM extra SRCS while_compute.cc DEPS ${lite_kernel_deps} math_arm)
 add_kernel(topk_compute_arm ARM extra SRCS topk_compute.cc DEPS ${lite_kernel_deps} math_arm)
 add_kernel(increment_compute_arm ARM extra SRCS increment_compute.cc DEPS ${lite_kernel_deps} math_arm)
-add_kernel(write_to_array_compute_arm ARM extra SRCS write_to_array_compute.cc DEPS ${lite_kernel_deps} math_arm)
-add_kernel(read_from_array_compute_arm ARM extra SRCS read_from_array_compute.cc DEPS ${lite_kernel_deps} math_arm)
 add_kernel(beam_search_compute_arm ARM extra SRCS beam_search_compute.cc DEPS ${lite_kernel_deps} math_arm)
 add_kernel(fill_constant_compute_arm ARM basic SRCS fill_constant_compute.cc DEPS ${lite_kernel_deps} math_arm)
 add_kernel(fill_constant_batch_size_like_compute_arm ARM basic SRCS fill_constant_batch_size_like_compute.cc DEPS ${lite_kernel_deps} math_arm)
...
@@ -114,14 +114,14 @@ struct BeamSearchDecoder {
     lod.push_back(source_level_lod);
     lod.push_back(sentence_level_lod);
-    *(id_tensor->mutable_lod()) = lod;
+    id_tensor->set_lod(lod);
     id_tensor->Resize({static_cast<int64_t>(id_data.size())});
     auto id_ptr = id_tensor->mutable_data<int64_t>();
     TargetCopy(
         TARGET(kARM), id_ptr, id_data.data(), id_data.size() * sizeof(int64_t));
-    *(score_tensor->mutable_lod()) = lod;
+    score_tensor->set_lod(lod);
     score_tensor->Resize({static_cast<int64_t>(score_data.size())});
     auto score_ptr = score_tensor->mutable_data<T>();
     TargetCopy(TARGET(kARM),
...
@@ -10,3 +10,5 @@ add_kernel(crf_decoding_compute_host Host extra SRCS crf_decoding_compute.cc DEP
 add_kernel(compare_compute_host Host extra SRCS compare_compute.cc DEPS ${lite_kernel_deps})
 add_kernel(logical_compute_host Host extra SRCS logical_compute.cc DEPS ${lite_kernel_deps})
 add_kernel(ctc_align_compute_host Host extra SRCS ctc_align_compute.cc DEPS ${lite_kernel_deps})
+add_kernel(write_to_array_compute_host Host extra SRCS write_to_array_compute.cc DEPS ${lite_kernel_deps})
+add_kernel(read_from_array_compute_host Host extra SRCS read_from_array_compute.cc DEPS ${lite_kernel_deps})
...
@@ -12,17 +12,15 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
-#include "lite/kernels/arm/read_from_array_compute.h"
-#include "lite/backends/arm/math/funcs.h"
+#include "lite/kernels/host/read_from_array_compute.h"
 namespace paddle {
 namespace lite {
 namespace kernels {
-namespace arm {
+namespace host {
 void ReadFromArrayCompute::Run() {
-  auto& ctx = this->ctx_->template As<ARMContext>();
-  auto& param = this->Param<param_t>();
+  auto& param = this->Param<operators::ReadFromArrayParam>();
   CHECK_EQ(param.I->numel(), 1) << "I should have only one element";
   int id = param.I->data<int64_t>()[0];
@@ -33,18 +31,27 @@ void ReadFromArrayCompute::Run() {
   param.Out->CopyDataFrom((*param.X)[id]);
 }
-}  // namespace arm
+}  // namespace host
 }  // namespace kernels
 }  // namespace lite
 }  // namespace paddle
 REGISTER_LITE_KERNEL(read_from_array,
-                     kARM,
+                     kHost,
                      kAny,
-                     kNCHW,
-                     paddle::lite::kernels::arm::ReadFromArrayCompute,
+                     kAny,
+                     paddle::lite::kernels::host::ReadFromArrayCompute,
                      def)
-    .BindInput("X", {LiteType::GetTensorListTy(TARGET(kARM), PRECISION(kAny))})
-    .BindInput("I", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt64))})
-    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kAny))})
+    .BindInput("X",
+               {LiteType::GetTensorListTy(TARGET(kHost),
+                                          PRECISION(kAny),
+                                          DATALAYOUT(kAny))})
+    .BindInput("I",
+               {LiteType::GetTensorTy(TARGET(kARM),
+                                      PRECISION(kInt64),
+                                      DATALAYOUT(kAny))})
+    .BindOutput("Out",
+                {LiteType::GetTensorTy(TARGET(kHost),
+                                       PRECISION(kAny),
+                                       DATALAYOUT(kAny))})
     .Finalize();
...
@@ -13,20 +13,17 @@
 // limitations under the License.
 #pragma once
-#include <stdint.h>
-#include "lite/backends/arm/math/type_trans.h"
 #include "lite/core/kernel.h"
 #include "lite/core/op_registry.h"
 namespace paddle {
 namespace lite {
 namespace kernels {
-namespace arm {
+namespace host {
-class ReadFromArrayCompute : public KernelLite<TARGET(kARM), PRECISION(kAny)> {
+class ReadFromArrayCompute
+    : public KernelLite<TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny)> {
  public:
-  using param_t = operators::ReadFromArrayParam;
   void Run() override;
   ~ReadFromArrayCompute() {}
@@ -34,7 +31,7 @@ class ReadFromArrayCompute : public KernelLite<TARGET(kARM), PRECISION(kAny)> {
  private:
 };
-}  // namespace arm
+}  // namespace host
 }  // namespace kernels
 }  // namespace lite
 }  // namespace paddle
...
@@ -12,16 +12,14 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
-#include "lite/kernels/arm/write_to_array_compute.h"
-#include "lite/backends/arm/math/funcs.h"
+#include "lite/kernels/host/write_to_array_compute.h"
 namespace paddle {
 namespace lite {
 namespace kernels {
-namespace arm {
+namespace host {
 void WriteToArrayCompute::Run() {
-  auto& ctx = this->ctx_->template As<ARMContext>();
   auto& param = this->template Param<operators::WriteToArrayParam>();
   CHECK_EQ(param.I->numel(), 1) << "input2 should have only one element";
@@ -32,19 +30,27 @@ void WriteToArrayCompute::Run() {
   param.Out->at(id).CopyDataFrom(*param.X);
 }
-}  // namespace arm
+}  // namespace host
 }  // namespace kernels
 }  // namespace lite
 }  // namespace paddle
 REGISTER_LITE_KERNEL(write_to_array,
-                     kARM,
+                     kHost,
                      kAny,
-                     kNCHW,
-                     paddle::lite::kernels::arm::WriteToArrayCompute,
+                     kAny,
+                     paddle::lite::kernels::host::WriteToArrayCompute,
                      def)
-    .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kAny))})
-    .BindInput("I", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt64))})
+    .BindInput("X",
+               {LiteType::GetTensorTy(TARGET(kHost),
+                                      PRECISION(kAny),
+                                      DATALAYOUT(kAny))})
+    .BindInput("I",
+               {LiteType::GetTensorTy(TARGET(kHost),
+                                      PRECISION(kInt64),
+                                      DATALAYOUT(kAny))})
     .BindOutput("Out",
-                {LiteType::GetTensorListTy(TARGET(kARM), PRECISION(kAny))})
+                {LiteType::GetTensorListTy(TARGET(kHost),
+                                           PRECISION(kAny),
+                                           DATALAYOUT(kAny))})
     .Finalize();
...
@@ -13,17 +13,16 @@
 // limitations under the License.
 #pragma once
-#include <stdint.h>
-#include "lite/backends/arm/math/type_trans.h"
 #include "lite/core/kernel.h"
 #include "lite/core/op_registry.h"
 namespace paddle {
 namespace lite {
 namespace kernels {
-namespace arm {
+namespace host {
-class WriteToArrayCompute : public KernelLite<TARGET(kARM), PRECISION(kAny)> {
+class WriteToArrayCompute
+    : public KernelLite<TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny)> {
  public:
  void Run() override;
@@ -32,7 +31,7 @@ class WriteToArrayCompute : public KernelLite<TARGET(kARM), PRECISION(kAny)> {
  private:
 };
-}  // namespace arm
+}  // namespace host
 }  // namespace kernels
 }  // namespace lite
 }  // namespace paddle
...
@@ -26,12 +26,7 @@ bool ReadFromArrayOp::CheckShape() const {
   return true;
 }
-bool ReadFromArrayOp::InferShapeImpl() const {
-  int id = param_.I->data<int64_t>()[0];
-  auto out_dims = (*param_.X)[id].dims();
-  param_.Out->Resize(out_dims);
-  return true;
-}
+bool ReadFromArrayOp::InferShapeImpl() const { return true; }
 bool ReadFromArrayOp::AttachImpl(const cpp::OpDesc &opdesc,
                                  lite::Scope *scope) {
...
@@ -26,13 +26,7 @@ bool WriteToArrayOp::CheckShape() const {
   return true;
 }
-bool WriteToArrayOp::InferShapeImpl() const {
-  int id = param_.I->data<int64_t>()[0];
-  if (param_.Out->size() < id + 1) {
-    param_.Out->resize(id + 1);
-  }
-  return true;
-}
+bool WriteToArrayOp::InferShapeImpl() const { return true; }
 bool WriteToArrayOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) {
   auto inputs = opdesc.Input("X").front();
...
@@ -88,7 +88,7 @@ TEST(ReadFromArray, precision) {
   Place place;
   float abs_error = 1e-5;
 #ifdef LITE_WITH_ARM
-  place = TARGET(kARM);
+  place = TARGET(kHost);
 #else
   return;
 #endif
...
@@ -85,7 +85,7 @@ TEST(WriteToArray, precision) {
   Place place;
   float abs_error = 1e-5;
 #ifdef LITE_WITH_ARM
-  place = TARGET(kARM);
+  place = TARGET(kHost);
 #else
   return;
 #endif
...