未验证 提交 3a04e11d 编写于 作者:zhupengyang 提交者:GitHub

move read_from_array and write_to_array to host (#3428)

上级 b445941f
......@@ -92,8 +92,6 @@ add_kernel(sequence_softmax_compute_arm ARM extra SRCS sequence_softmax_compute.
add_kernel(while_compute_arm ARM extra SRCS while_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(topk_compute_arm ARM extra SRCS topk_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(increment_compute_arm ARM extra SRCS increment_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(write_to_array_compute_arm ARM extra SRCS write_to_array_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(read_from_array_compute_arm ARM extra SRCS read_from_array_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(beam_search_compute_arm ARM extra SRCS beam_search_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(fill_constant_compute_arm ARM basic SRCS fill_constant_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(fill_constant_batch_size_like_compute_arm ARM basic SRCS fill_constant_batch_size_like_compute.cc DEPS ${lite_kernel_deps} math_arm)
......
......@@ -114,14 +114,14 @@ struct BeamSearchDecoder {
lod.push_back(source_level_lod);
lod.push_back(sentence_level_lod);
*(id_tensor->mutable_lod()) = lod;
id_tensor->set_lod(lod);
id_tensor->Resize({static_cast<int64_t>(id_data.size())});
auto id_ptr = id_tensor->mutable_data<int64_t>();
TargetCopy(
TARGET(kARM), id_ptr, id_data.data(), id_data.size() * sizeof(int64_t));
*(score_tensor->mutable_lod()) = lod;
score_tensor->set_lod(lod);
score_tensor->Resize({static_cast<int64_t>(score_data.size())});
auto score_ptr = score_tensor->mutable_data<T>();
TargetCopy(TARGET(kARM),
......
......@@ -10,3 +10,5 @@ add_kernel(crf_decoding_compute_host Host extra SRCS crf_decoding_compute.cc DEP
add_kernel(compare_compute_host Host extra SRCS compare_compute.cc DEPS ${lite_kernel_deps})
add_kernel(logical_compute_host Host extra SRCS logical_compute.cc DEPS ${lite_kernel_deps})
add_kernel(ctc_align_compute_host Host extra SRCS ctc_align_compute.cc DEPS ${lite_kernel_deps})
add_kernel(write_to_array_compute_host Host extra SRCS write_to_array_compute.cc DEPS ${lite_kernel_deps})
add_kernel(read_from_array_compute_host Host extra SRCS read_from_array_compute.cc DEPS ${lite_kernel_deps})
......@@ -12,17 +12,15 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/arm/read_from_array_compute.h"
#include "lite/backends/arm/math/funcs.h"
#include "lite/kernels/host/read_from_array_compute.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace arm {
namespace host {
void ReadFromArrayCompute::Run() {
auto& ctx = this->ctx_->template As<ARMContext>();
auto& param = this->Param<param_t>();
auto& param = this->Param<operators::ReadFromArrayParam>();
CHECK_EQ(param.I->numel(), 1) << "I should have only one element";
int id = param.I->data<int64_t>()[0];
......@@ -33,18 +31,27 @@ void ReadFromArrayCompute::Run() {
param.Out->CopyDataFrom((*param.X)[id]);
}
} // namespace arm
} // namespace host
} // namespace kernels
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(read_from_array,
kARM,
kHost,
kAny,
kNCHW,
paddle::lite::kernels::arm::ReadFromArrayCompute,
kAny,
paddle::lite::kernels::host::ReadFromArrayCompute,
def)
.BindInput("X", {LiteType::GetTensorListTy(TARGET(kARM), PRECISION(kAny))})
.BindInput("I", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt64))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kAny))})
.BindInput("X",
{LiteType::GetTensorListTy(TARGET(kHost),
PRECISION(kAny),
DATALAYOUT(kAny))})
.BindInput("I",
{LiteType::GetTensorTy(TARGET(kARM),
PRECISION(kInt64),
DATALAYOUT(kAny))})
.BindOutput("Out",
{LiteType::GetTensorTy(TARGET(kHost),
PRECISION(kAny),
DATALAYOUT(kAny))})
.Finalize();
......@@ -13,20 +13,17 @@
// limitations under the License.
#pragma once
#include <stdint.h>
#include "lite/backends/arm/math/type_trans.h"
#include "lite/core/kernel.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace arm {
namespace host {
class ReadFromArrayCompute : public KernelLite<TARGET(kARM), PRECISION(kAny)> {
class ReadFromArrayCompute
: public KernelLite<TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny)> {
public:
using param_t = operators::ReadFromArrayParam;
void Run() override;
~ReadFromArrayCompute() {}
......@@ -34,7 +31,7 @@ class ReadFromArrayCompute : public KernelLite<TARGET(kARM), PRECISION(kAny)> {
private:
};
} // namespace arm
} // namespace host
} // namespace kernels
} // namespace lite
} // namespace paddle
......@@ -12,16 +12,14 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/arm/write_to_array_compute.h"
#include "lite/backends/arm/math/funcs.h"
#include "lite/kernels/host/write_to_array_compute.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace arm {
namespace host {
void WriteToArrayCompute::Run() {
auto& ctx = this->ctx_->template As<ARMContext>();
auto& param = this->template Param<operators::WriteToArrayParam>();
CHECK_EQ(param.I->numel(), 1) << "input2 should have only one element";
......@@ -32,19 +30,27 @@ void WriteToArrayCompute::Run() {
param.Out->at(id).CopyDataFrom(*param.X);
}
} // namespace arm
} // namespace host
} // namespace kernels
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(write_to_array,
kARM,
kHost,
kAny,
kNCHW,
paddle::lite::kernels::arm::WriteToArrayCompute,
kAny,
paddle::lite::kernels::host::WriteToArrayCompute,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kAny))})
.BindInput("I", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt64))})
.BindInput("X",
{LiteType::GetTensorTy(TARGET(kHost),
PRECISION(kAny),
DATALAYOUT(kAny))})
.BindInput("I",
{LiteType::GetTensorTy(TARGET(kHost),
PRECISION(kInt64),
DATALAYOUT(kAny))})
.BindOutput("Out",
{LiteType::GetTensorListTy(TARGET(kARM), PRECISION(kAny))})
{LiteType::GetTensorListTy(TARGET(kHost),
PRECISION(kAny),
DATALAYOUT(kAny))})
.Finalize();
......@@ -13,17 +13,16 @@
// limitations under the License.
#pragma once
#include <stdint.h>
#include "lite/backends/arm/math/type_trans.h"
#include "lite/core/kernel.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace arm {
namespace host {
class WriteToArrayCompute : public KernelLite<TARGET(kARM), PRECISION(kAny)> {
class WriteToArrayCompute
: public KernelLite<TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny)> {
public:
void Run() override;
......@@ -32,7 +31,7 @@ class WriteToArrayCompute : public KernelLite<TARGET(kARM), PRECISION(kAny)> {
private:
};
} // namespace arm
} // namespace host
} // namespace kernels
} // namespace lite
} // namespace paddle
......@@ -26,12 +26,7 @@ bool ReadFromArrayOp::CheckShape() const {
return true;
}
bool ReadFromArrayOp::InferShapeImpl() const {
int id = param_.I->data<int64_t>()[0];
auto out_dims = (*param_.X)[id].dims();
param_.Out->Resize(out_dims);
return true;
}
bool ReadFromArrayOp::InferShapeImpl() const { return true; }
bool ReadFromArrayOp::AttachImpl(const cpp::OpDesc &opdesc,
lite::Scope *scope) {
......
......@@ -26,13 +26,7 @@ bool WriteToArrayOp::CheckShape() const {
return true;
}
bool WriteToArrayOp::InferShapeImpl() const {
int id = param_.I->data<int64_t>()[0];
if (param_.Out->size() < id + 1) {
param_.Out->resize(id + 1);
}
return true;
}
bool WriteToArrayOp::InferShapeImpl() const { return true; }
bool WriteToArrayOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) {
auto inputs = opdesc.Input("X").front();
......
......@@ -88,7 +88,7 @@ TEST(ReadFromArray, precision) {
Place place;
float abs_error = 1e-5;
#ifdef LITE_WITH_ARM
place = TARGET(kARM);
place = TARGET(kHost);
#else
return;
#endif
......
......@@ -85,7 +85,7 @@ TEST(WriteToArray, precision) {
Place place;
float abs_error = 1e-5;
#ifdef LITE_WITH_ARM
place = TARGET(kARM);
place = TARGET(kHost);
#else
return;
#endif
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册