Commit 0e77cd63 authored by zhupengyang, committed by GitHub

avoid reusing non-tensor in memory_optimize_pass (#3111)

Parent c79a0954
......
@@ -78,6 +78,7 @@ void MemoryOptimizePass::CollectLifeCycleByDevice(
   // Collect the invalid input and output variables that will not be reused.
   std::unordered_set<std::string> invalid_var_names;
   for (auto& op_node : graph->StmtTopologicalOrder()) {
+    // variables of invalid_op_nodes will not be reused
     if (!op_node->IsStmt()) continue;
     auto op_info = op_node->AsStmt().op_info();
     auto op_type = op_info->Type();
......
@@ -120,6 +121,13 @@
     }
   }
+
+  // non-tensor (like tensor_array) variables will not be reused
+  for (auto& node : graph->nodes()) {
+    if (node.IsArg() && !node.arg()->type->IsTensor()) {
+      invalid_var_names.insert(node.arg()->name);
+    }
+  }
   for (auto& op_node : graph->StmtTopologicalOrder()) {
     if (op_node->IsStmt()) {
       std::vector<Node*> var_nodes(op_node->inlinks.begin(),
......
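
The new loop walks every Arg node in the graph and marks any variable whose type is not a plain tensor (e.g. a tensor_array, which is backed by a std::vector<lite::Tensor>) as invalid for reuse, since the pass can only plan buffer sharing for single contiguous tensor buffers. A minimal standalone sketch of that filtering idea, with VarDesc/is_tensor as hypothetical stand-ins for the real graph node API:

#include <string>
#include <unordered_set>
#include <vector>

// Hypothetical stand-in for a graph Arg node: just a name plus a flag
// saying whether the variable holds a single plain tensor.
struct VarDesc {
  std::string name;
  bool is_tensor;  // false for tensor_array-like (multi-buffer) variables
};

// Mirror of the pass logic: anything that is not a plain tensor goes
// into the invalid set and is skipped when buffers are assigned.
std::unordered_set<std::string> CollectInvalidVars(
    const std::vector<VarDesc>& vars) {
  std::unordered_set<std::string> invalid;
  for (const auto& v : vars) {
    if (!v.is_tensor) invalid.insert(v.name);
  }
  return invalid;
}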
......
@@ -182,32 +182,32 @@ void ElementwiseSubActivationCompute::Run() {
 template <typename T, PrecisionType PType>
 void ElementwiseMulCompute<T, PType>::Run() {
   auto& param = this->template Param<operators::ElementwiseParam>();
-  if (param.X->precision() == PRECISION(kFloat)) {
-    auto* x_data = param.X->template data<float>();
-    auto* y_data = param.Y->template data<float>();
-    auto* out_data = param.Out->template mutable_data<float>();
-    int axis = param.axis;
-    auto x_dims = param.X->dims();
-    auto y_dims = param.Y->dims();
-    int pre, n, post;
-    if (x_dims.size() < y_dims.size() &&
-        is_broadcast(y_dims, x_dims, axis, &pre, &n, &post)) {
-      lite::arm::math::elementwise_mul_broadcast<float>(
-          y_data, x_data, out_data, pre, n, post);
-    } else if (is_broadcast(x_dims, y_dims, axis, &pre, &n, &post)) {
-      lite::arm::math::elementwise_mul_broadcast<float>(
-          x_data, y_data, out_data, pre, n, post);
-    } else {
-      lite::arm::math::elementwise_mul<float>(
-          x_data, y_data, out_data, x_dims.production());
-    }
-  } else if (param.X->precision() == PRECISION(kInt64)) {
-    lite::arm::math::elementwise_compute_basic<int64_t>(param, "mul", "");
-  } else {
-    LOG(FATAL) << "unsupport input type";
-  }
+  auto* x_data = param.X->template data<T>();
+  auto* y_data = param.Y->template data<T>();
+  auto* out_data = param.Out->template mutable_data<T>();
+  int axis = param.axis;
+  auto x_dims = param.X->dims();
+  auto y_dims = param.Y->dims();
+  int pre, n, post;
+  if (x_dims.size() < y_dims.size() &&
+      is_broadcast(y_dims, x_dims, axis, &pre, &n, &post)) {
+    lite::arm::math::elementwise_mul_broadcast<T>(
+        y_data, x_data, out_data, pre, n, post);
+  } else if (is_broadcast(x_dims, y_dims, axis, &pre, &n, &post)) {
+    lite::arm::math::elementwise_mul_broadcast<T>(
+        x_data, y_data, out_data, pre, n, post);
+  } else {
+    lite::arm::math::elementwise_mul<T>(
+        x_data, y_data, out_data, x_dims.production());
+  }
 }
 
+template <>
+void ElementwiseMulCompute<int64_t, PRECISION(kInt64)>::Run() {
+  auto& param = this->template Param<operators::ElementwiseParam>();
+  lite::arm::math::elementwise_compute_basic<int64_t>(param, "mul", "");
+}
+
 void ElementwiseMulActivationCompute::Run() {
   auto& param = Param<operators::FusionElementwiseActivationParam>();
   const float* x_data = param.X->data<float>();
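
The rewrite replaces the runtime precision() branching with a compile-time choice: the generic Run body works for any T that has an optimized elementwise_mul<T> path, while a full template specialization routes int64_t to the generic elementwise_compute_basic fallback. A rough sketch of the same dispatch pattern, with placeholder names rather than the Paddle Lite API:

#include <cstdint>

// Generic fast path; stands in for lite::arm::math::elementwise_mul<T>.
template <typename T>
void FastMul(const T* x, const T* y, T* out, int n) {
  for (int i = 0; i < n; ++i) out[i] = x[i] * y[i];
}

// Generic kernel: every supported T uses the optimized routine.
template <typename T>
struct MulCompute {
  void Run(const T* x, const T* y, T* out, int n) { FastMul(x, y, out, n); }
};

// Full specialization: int64_t falls back to a plain scalar loop,
// mirroring the elementwise_compute_basic<int64_t> fallback above.
template <>
struct MulCompute<int64_t> {
  void Run(const int64_t* x, const int64_t* y, int64_t* out, int n) {
    for (int i = 0; i < n; ++i) out[i] = x[i] * y[i];
  }
};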
......
@@ -420,6 +420,16 @@ REGISTER_LITE_KERNEL(
     .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
     .Finalize();
 
+using elementwise_mul_int64 =
+    paddle::lite::kernels::arm::ElementwiseMulCompute<int64_t,
+                                                      PRECISION(kInt64)>;
+REGISTER_LITE_KERNEL(
+    elementwise_mul, kARM, kInt64, kNCHW, elementwise_mul_int64, def)
+    .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt64))})
+    .BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt64))})
+    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt64))})
+    .Finalize();
+
 REGISTER_LITE_KERNEL(
     fusion_elementwise_mul_activation,
     kARM,
......
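
REGISTER_LITE_KERNEL files the new int64 instantiation under its (op, target, precision, layout) tuple, so kernel selection can now pick elementwise_mul for int64 inputs on ARM. A simplified sketch of how such a static registry can work; this is illustrative only, the real macro does considerably more:

#include <functional>
#include <map>
#include <string>
#include <tuple>

// A kernel is looked up by (op, target, precision, layout).
using KernelKey =
    std::tuple<std::string, std::string, std::string, std::string>;

std::map<KernelKey, std::function<void()>>& Registry() {
  static std::map<KernelKey, std::function<void()>> r;  // built on first use
  return r;
}

// Registration runs in a file-scope initializer, so merely linking the
// kernel's object file adds the entry before main() starts.
static bool g_reg_mul_int64 = [] {
  Registry()[{"elementwise_mul", "kARM", "kInt64", "kNCHW"}] = [] {
    // construct ElementwiseMulCompute<int64_t, PRECISION(kInt64)> here
  };
  return true;
}();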
......
@@ -730,15 +730,15 @@ struct IncrementParam {
 };
 
 struct WriteToArrayParam {
-  const lite::Tensor* X{};
-  const lite::Tensor* I{};
-  std::vector<lite::Tensor>* Out{};
+  const lite::Tensor* X{nullptr};
+  const lite::Tensor* I{nullptr};
+  std::vector<lite::Tensor>* Out{nullptr};
 };
 
 struct ReadFromArrayParam {
-  const std::vector<lite::Tensor>* X{};
-  const lite::Tensor* I{};
-  lite::Tensor* Out{};
+  const std::vector<lite::Tensor>* X{nullptr};
+  const lite::Tensor* I{nullptr};
+  lite::Tensor* Out{nullptr};
 };
 
 struct BeamSearchParam {
......
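
Functionally this last hunk is a no-op: for a pointer member, {} value-initializes to null exactly as {nullptr} does; the change only spells the default out. A quick self-contained check:

#include <cassert>

struct Params {
  const int* a{};         // value-initialized: already nullptr
  const int* b{nullptr};  // same result, just explicit
};

int main() {
  Params p;
  assert(p.a == nullptr && p.b == nullptr);  // both defaults are null
  return 0;
}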