Unverified commit afefe9cf authored by zhupengyang, committed by GitHub

assign support tensor_array (#3436)

Parent 2168ff38
@@ -56,7 +56,6 @@ add_kernel(negative_compute_arm ARM extra SRCS negative_compute.cc DEPS ${lite_kernel_deps} math_arm)
 add_kernel(crop_compute_arm ARM extra SRCS crop_compute.cc DEPS ${lite_kernel_deps} math_arm)
 add_kernel(power_compute_arm ARM extra SRCS power_compute.cc DEPS ${lite_kernel_deps} math_arm)
 add_kernel(norm_compute_arm ARM extra SRCS norm_compute.cc DEPS ${lite_kernel_deps} math_arm)
-add_kernel(assign_compute_arm ARM extra SRCS assign_compute.cc DEPS ${lite_kernel_deps} math_arm)
 ## 3. extra kernels
 add_kernel(lrn_compute_arm ARM extra SRCS lrn_compute.cc DEPS ${lite_kernel_deps} math_arm)
@@ -12,3 +12,4 @@ add_kernel(logical_compute_host Host extra SRCS logical_compute.cc DEPS ${lite_kernel_deps})
 add_kernel(ctc_align_compute_host Host extra SRCS ctc_align_compute.cc DEPS ${lite_kernel_deps})
 add_kernel(write_to_array_compute_host Host extra SRCS write_to_array_compute.cc DEPS ${lite_kernel_deps})
 add_kernel(read_from_array_compute_host Host extra SRCS read_from_array_compute.cc DEPS ${lite_kernel_deps})
+add_kernel(assign_compute_host Host extra SRCS assign_compute.cc DEPS ${lite_kernel_deps})
@@ -12,29 +12,42 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.

-#include "lite/kernels/arm/assign_compute.h"
-#include <vector>
-#include "lite/backends/arm/math/funcs.h"
-#include "lite/core/op_registry.h"
-#include "lite/core/type_system.h"
+#include "lite/kernels/host/assign_compute.h"

 namespace paddle {
 namespace lite {
 namespace kernels {
-namespace arm {
+namespace host {

 void AssignCompute::Run() {
   auto& param = Param<param_t>();
-  param.Out->CopyDataFrom(*param.X);
+  if (param.X != nullptr) {
+    param.Out->CopyDataFrom(*param.X);
+  } else if (param.X_array != nullptr) {
+    auto x_array = param.X_array;
+    auto out_array = param.Out_array;
+    out_array->resize(x_array->size());
+    for (size_t i = 0; i < x_array->size(); i++) {
+      out_array->at(i).CopyDataFrom(x_array->at(i));
+    }
+  } else {
+    LOG(FATAL) << "x or x_array of assign must be set.";
+  }
 }

-}  // namespace arm
+}  // namespace host
 }  // namespace kernels
 }  // namespace lite
 }  // namespace paddle

 REGISTER_LITE_KERNEL(
-    assign, kARM, kAny, kNCHW, paddle::lite::kernels::arm::AssignCompute, def)
-    .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kAny))})
-    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kAny))})
+    assign, kHost, kAny, kAny, paddle::lite::kernels::host::AssignCompute, def)
+    .BindInput("X",
+               {LiteType::GetTensorTy(TARGET(kHost),
+                                      PRECISION(kAny),
+                                      DATALAYOUT(kAny))})
+    .BindOutput("Out",
+                {LiteType::GetTensorTy(TARGET(kHost),
+                                       PRECISION(kAny),
+                                       DATALAYOUT(kAny))})
     .Finalize();
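
The rewritten Run() above branches on which input was bound: when X holds a single tensor, the output is filled with one CopyDataFrom; when X_array holds a tensor_array, the output array is resized and every element is copied individually. The standalone sketch below mirrors that per-element copy; it deliberately uses a simplified stand-in type instead of lite::Tensor, so every name in it is illustrative rather than part of the patch.

#include <cassert>
#include <cstddef>
#include <vector>

// Simplified stand-in for lite::Tensor, just enough to show the copy pattern.
struct FakeTensor {
  std::vector<float> data;
  void CopyDataFrom(const FakeTensor& other) { data = other.data; }
};

// Mirrors the tensor_array branch of AssignCompute::Run(): size the output
// array to match the input, then deep-copy each element.
void AssignTensorArray(const std::vector<FakeTensor>& x_array,
                       std::vector<FakeTensor>* out_array) {
  out_array->resize(x_array.size());
  for (size_t i = 0; i < x_array.size(); ++i) {
    out_array->at(i).CopyDataFrom(x_array.at(i));
  }
}

int main() {
  std::vector<FakeTensor> x = {{{1.f, 2.f}}, {{3.f}}};
  std::vector<FakeTensor> out;
  AssignTensorArray(x, &out);
  assert(out.size() == 2 && out[1].data[0] == 3.f);
  return 0;
}
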
@@ -15,14 +15,15 @@
 #pragma once
 #include <algorithm>
 #include "lite/core/kernel.h"
-#include "lite/operators/assign_op.h"
+#include "lite/core/op_registry.h"

 namespace paddle {
 namespace lite {
 namespace kernels {
-namespace arm {
+namespace host {

-class AssignCompute : public KernelLite<TARGET(kARM), PRECISION(kAny)> {
+class AssignCompute
+    : public KernelLite<TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny)> {
  public:
   using param_t = operators::AssignParam;

@@ -31,7 +32,7 @@ class AssignCompute : public KernelLite<TARGET(kARM), PRECISION(kAny)> {
   virtual ~AssignCompute() = default;
 };

-}  // namespace arm
+}  // namespace host
 }  // namespace kernels
 }  // namespace lite
 }  // namespace paddle
@@ -27,20 +27,33 @@ bool AssignOpLite::CheckShape() const {
 }

 bool AssignOpLite::InferShapeImpl() const {
-  lite::DDim input_dims;
-  input_dims = param_.X->dims();
-  param_.Out->Resize(lite::DDim(input_dims));
+  if (param_.X != nullptr) {
+    param_.Out->Resize(param_.X->dims());
+  } else if (param_.X_array != nullptr) {
+    param_.Out_array->resize(param_.Out_array->size());
+  } else {
+    LOG(FATAL) << "x or x_array must be set.";
+  }
   return true;
 }

 // TODO(Superjomn) replace framework::OpDesc with a lite one.
 bool AssignOpLite::AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) {
-  auto input = op_desc.Input("X").front();
-  auto out = op_desc.Output("Out").front();
+  auto x_name = op_desc.Input("X").front();
+  auto out_name = op_desc.Output("Out").front();

-  param_.X = scope->FindVar(input)->GetMutable<lite::Tensor>();
-  CHECK(scope->FindVar(out));
-  param_.Out = scope->FindVar(out)->GetMutable<lite::Tensor>();
+  auto x_var = scope->FindVar(x_name);
+  if (x_var->IsType<Tensor>()) {
+    param_.X = scope->FindTensor(x_name);
+    param_.Out = scope->FindMutableTensor(out_name);
+  } else if (x_var->IsType<std::vector<Tensor>>()) {
+    param_.X_array = x_var->GetMutable<std::vector<Tensor>>();
+    param_.Out_array =
+        scope->FindVar(out_name)->GetMutable<std::vector<Tensor>>();
+  } else {
+    LOG(FATAL) << "X type for assign op is unsupported. Expected type is "
+                  "tensor or tensor_array.";
+  }

   return true;
 }
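
The new AttachImpl above dispatches on the runtime type of the X variable in the scope: a plain Tensor binds param_.X / param_.Out, while a std::vector<Tensor> (the tensor_array representation) binds param_.X_array / param_.Out_array, and anything else is rejected. As a rough illustration of how the tensor_array branch gets exercised, the sketch below populates a scope with a tensor_array variable; it assumes the lite::Scope::Var and Variable::GetMutable API seen in the diff above, and the variable names are hypothetical, not taken from the commit.

#include <vector>
#include "lite/core/scope.h"
#include "lite/core/tensor.h"

// Illustrative only: create a tensor_array variable ("x") and an output
// variable ("out") so that AssignOpLite::AttachImpl would follow the
// X_array branch when the op is attached to this scope.
void PrepareTensorArrayInput(paddle::lite::Scope* scope) {
  auto* x_var = scope->Var("x");  // creates the variable if it is absent
  auto* x_array = x_var->GetMutable<std::vector<paddle::lite::Tensor>>();
  x_array->resize(2);             // a two-element tensor_array
  scope->Var("out");              // output slot, written by the kernel later
  // AttachImpl would now see x_var->IsType<std::vector<Tensor>>() == true.
}
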
@@ -1279,8 +1279,13 @@ struct GatherParam : ParamBase {
 /// ----------------------- assign operators -----------------------
 struct AssignParam : ParamBase {
-  const lite::Tensor* X{};
-  lite::Tensor* Out{};
+  // for tensor
+  const lite::Tensor* X{nullptr};
+  lite::Tensor* Out{nullptr};
+  // for tensor_array
+  const std::vector<lite::Tensor>* X_array{nullptr};
+  std::vector<lite::Tensor>* Out_array{nullptr};
 };

 /// ----------------------- roi_align operators -----------------------
@@ -69,7 +69,7 @@ void TestAssign(const Place& place) {
 TEST(Assign, precision) {
   Place place;
 #ifdef LITE_WITH_ARM
-  place = TARGET(kARM);
+  place = TARGET(kHost);
 #else
   return;
 #endif