Commit 6871c0c7 authored by J jiaopu

BatchSize changeable

Parent 104dd1d6
......@@ -47,7 +47,6 @@ class Graph {
CNRT_CALL(cnrtCreateNotifier(&notifier_end_));
#endif
}
~Graph() {
FreeConstData();
CNML_CALL(cnmlDestroyFusionOp(&fusion_op_));
......@@ -62,7 +61,6 @@ class Graph {
<< " process:" << total_time / time_log_.size() << std::endl;
#endif
}
// Data node
std::shared_ptr<MLUTensor> AddNode(
const std::string& name,
......@@ -84,6 +82,10 @@ class Graph {
void AddInput(std::shared_ptr<MLUTensor> tensor) {
inputs_.push_back(tensor->mlu_tensor());
input_tensors_.push_back(tensor);
constexpr int input_dimNb = 4;
bool input_dim_mutable[input_dimNb] = {true, false, false, false};
cnmlSetTensorDimMutable(
tensor->mlu_tensor(), input_dim_mutable, input_dimNb);
}
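
Note: the block above marks only dim 0 (the batch dimension) of each input tensor as mutable, so a fusion op compiled once can be fed different batch sizes. A minimal sketch of the flag-array pattern, generalized to any rank; FillBatchMutableMask is a hypothetical helper, not part of this patch:

// Hypothetical helper (not in the patch): fill a caller-owned mask so
// that only dim 0 (batch) may change between invocations, mirroring
// the input_dim_mutable array passed to cnmlSetTensorDimMutable above.
void FillBatchMutableMask(bool* mask, int dim_nb) {
  for (int i = 0; i < dim_nb; ++i) mask[i] = (i == 0);
}

With such a mask, the hard-coded 4-D case above could support inputs of any rank.
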
void AddOutput(std::shared_ptr<MLUTensor> tensor) {
......@@ -119,27 +121,38 @@ class Graph {
CNML_CALL(cnmlCompileFusionOp_V2(fusion_op_));
}
void Compute(cnrtInvokeFuncParam_t forward_param, cnrtQueue_t que) {
input_addrs_.resize(input_tensors_.size());
output_addrs_.resize(output_tensors_.size());
void Compute(cnrtInvokeFuncParam_t forward_param,
cnrtQueue_t que,
std::vector<std::shared_ptr<MLUTensor>> in,
std::vector<std::shared_ptr<MLUTensor>> out) {
std::vector<cnmlTensor_t> in_tensor;
std::vector<cnmlTensor_t> out_tensor;
input_addrs_.resize(in.size());
output_addrs_.resize(out.size());
for (size_t i = 0; i < input_addrs_.size(); ++i) {
input_addrs_[i] = input_tensors_[i]->mlu_data();
input_addrs_[i] = in[i]->mlu_data();
in_tensor.push_back(in[i]->mlu_tensor());
}
for (size_t i = 0; i < output_addrs_.size(); ++i) {
output_addrs_[i] = output_tensors_[i]->mlu_data();
output_addrs_[i] = out[i]->mlu_data();
out_tensor.push_back(out[i]->mlu_tensor());
}
#if PRINT_HW_TIME
thread_local float hw_time;
CNRT_CALL(cnrtPlaceNotifier(notifier_start_, que));
#endif
CNML_CALL(cnmlComputeFusionOpForward_V3(fusion_op_,
/* Since cnmlSetTensorDimMutable is used, call
 * cnmlComputeFusionOpForward_V4 instead of cnmlComputeFusionOpForward_V3. */
CNML_CALL(cnmlComputeFusionOpForward_V4(fusion_op_,
&in_tensor[0],
input_addrs_.data(),
input_addrs_.size(),
&out_tensor[0],
output_addrs_.data(),
output_addrs_.size(),
&forward_param,
que));
que,
NULL));
#if PRINT_HW_TIME
CNRT_CALL(cnrtPlaceNotifier(notifier_end_, que));
CNRT_CALL(cnrtSyncQueue(que));
......
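
The V3-to-V4 switch is the runtime half of this change: cnmlComputeFusionOpForward_V4 takes the cnmlTensor_t handles per call, so the shapes seen at execution time come from the tensors the caller passes in rather than from the ones recorded at Compile time. A minimal stand-in sketch of the new calling convention; FakeTensor and FakeGraph are illustrative stand-ins, not the real CNML API:

#include <memory>
#include <vector>

// Illustrative stand-ins; the real code uses MLUTensor and cnmlTensor_t.
struct FakeTensor {
  void* device_ptr = nullptr;
};

struct FakeGraph {
  // Per-call tensors replace the graph's stored input_tensors_ /
  // output_tensors_, so one compiled graph serves many batch sizes.
  void Compute(const std::vector<std::shared_ptr<FakeTensor>>& in,
               const std::vector<std::shared_ptr<FakeTensor>>& out) {
    std::vector<void*> in_addrs, out_addrs;
    for (const auto& t : in) in_addrs.push_back(t->device_ptr);
    for (const auto& t : out) out_addrs.push_back(t->device_ptr);
    // ... forward in_addrs/out_addrs plus the tensor handles to the
    // runtime, as the cnmlComputeFusionOpForward_V4 call above does.
  }
};
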
......@@ -49,6 +49,7 @@ class MLUTensor {
return mlu_ptr_;
}
cnmlDataType_t dtype() { return mlu_dtype_; }
void set_mlu_dtype(cnmlDataType_t type) { mlu_dtype_ = type; }
const std::vector<int64_t>& get_origin_shape() const { return origin_shape_; }
......
......@@ -89,7 +89,10 @@ void LaunchOp(const std::shared_ptr<lite::OpLite> op,
}
graph.Compile(CNML_MLU270, 1);
graph.Compute(forward_param, queue_);
graph.Compute(forward_param,
queue_,
*(graph.MutableInputs()),
*(graph.MutableOutputs()));
CNRT_CALL(cnrtSyncQueue(queue_));
for (auto& output_name : output_var_names) {
......
......@@ -22,12 +22,16 @@
#include "lite/api/paddle_place.h"
#include "lite/core/kernel.h"
#include "lite/core/op_lite.h"
#include "lite/core/op_registry.h"
#include "lite/core/tensor.h"
#include "lite/core/type_system.h"
#include "lite/core/types.h"
#include "lite/kernels/mlu/bridges/graph.h"
#include "lite/kernels/mlu/bridges/tensor.h"
#include "lite/kernels/npu/bridges/engine.h"
#include "lite/kernels/npu/bridges/registry.h"
#include "lite/utils/env.h"
namespace paddle {
namespace lite {
......@@ -77,7 +81,9 @@ class SubgraphEngine : public subgraph::Engine {
bool InputShapeChanged() {
std::vector<std::vector<int64_t>> new_shape;
for (auto origin_itensor : origin_itensors_) {
new_shape.push_back(origin_itensor->dims().Vectorize());
auto iv = origin_itensor->dims().Vectorize();
iv.erase(iv.begin());
new_shape.push_back(iv);
}
inputs_shape_ = new_shape;
if (shape_graph_map_.count(inputs_shape_) > 0) {
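
With the batch dimension erased, inputs that differ only in batch size produce the same key into shape_graph_map_, so the compiled graph is reused and only a genuine change in the remaining dimensions triggers a rebuild. A self-contained sketch of that keying behavior; the std::map of int stands in for the graph cache, and MakeKey is illustrative:

#include <cstdint>
#include <iostream>
#include <map>
#include <vector>

using Shape = std::vector<int64_t>;
using ShapeKey = std::vector<Shape>;

// Drop dim 0 (batch) before keying, as InputShapeChanged does above.
ShapeKey MakeKey(const std::vector<Shape>& dims) {
  ShapeKey key;
  for (auto d : dims) {  // copy each shape, then erase its batch dim
    d.erase(d.begin());
    key.push_back(std::move(d));
  }
  return key;
}

int main() {
  std::map<ShapeKey, int> graph_cache;  // stand-in for shape_graph_map_
  graph_cache[MakeKey({{1, 3, 224, 224}})] = 42;
  // A batch-8 input maps to the same key, so the cached graph is hit.
  std::cout << graph_cache.count(MakeKey({{8, 3, 224, 224}})) << "\n";  // 1
  return 0;
}
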
......@@ -99,9 +105,10 @@ class SubgraphEngine : public subgraph::Engine {
status |= subgraph::REBUILD_WHEN_SHAPE_CHANGED;
for (auto& input_name : input_names_) {
auto input_tensor = scope_->FindMutableTensor(input_name);
origin_itensors_.push_back(input_tensor);
new_shape.push_back(input_tensor->dims().Vectorize());
auto iv = input_tensor->dims().Vectorize();
iv.erase(iv.begin());
new_shape.push_back(iv);
CHECK(input_tensor);
auto input_node = graph->AddNode(input_name,
......@@ -222,9 +229,20 @@ class SubgraphEngine : public subgraph::Engine {
CHECK_EQ(graph_input->size(), origin_itensors_.size());
CHECK_EQ(graph_output->size(), origin_otensors_.size());
std::vector<std::shared_ptr<paddle::lite::subgraph::mlu::MLUTensor>>
graph_in;
graph_in.reserve(origin_itensors_.size());
std::vector<std::shared_ptr<paddle::lite::subgraph::mlu::MLUTensor>>
graph_out;
graph_out.reserve(origin_otensors_.size());
for (size_t i = 0; i < origin_itensors_.size(); ++i) {
graph_input->at(i)->set_mlu_ptr(
const_cast<void*>(origin_itensors_[i]->raw_data()));
paddle::lite::subgraph::mlu::MLUTensor tmp(
graph_input->at(i)->get_origin_shape());
tmp.set_mlu_dtype(graph_input->at(i)->dtype());
tmp.set_mlu_ptr(const_cast<void*>(origin_itensors_[i]->raw_data()));
graph_in.push_back(
std::make_shared<paddle::lite::subgraph::mlu::MLUTensor>(tmp));
}
for (size_t i = 0; i < origin_otensors_.size(); ++i) {
origin_otensors_[i]->Resize(graph_output->at(i)->get_origin_shape());
......@@ -232,7 +250,12 @@ class SubgraphEngine : public subgraph::Engine {
origin_otensors_[i]
->mutable_data<typename paddle::lite::subgraph::mlu::FPTypeTraits<
Precision>::T>(TARGET(kMLU)));
graph_output->at(i)->set_mlu_ptr(p_data);
paddle::lite::subgraph::mlu::MLUTensor tmp(
graph_output->at(i)->get_origin_shape());
tmp.set_mlu_dtype(graph_output->at(i)->dtype());
tmp.set_mlu_ptr(p_data);
graph_out.push_back(
std::make_shared<paddle::lite::subgraph::mlu::MLUTensor>(tmp));
}
auto& mlu_context = this->ctx_->template As<MLUContext>();
......@@ -244,7 +267,7 @@ class SubgraphEngine : public subgraph::Engine {
forward_param.affinity = &affinity;
forward_param.end = CNRT_PARAM_END;
graph->Compute(forward_param, exec_queue);
graph->Compute(forward_param, exec_queue, graph_in, graph_out);
// // =========== DUMP ===================
// for (auto input_name : input_names_) {
......
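
Each run now builds throwaway MLUTensor wrappers that carry the compiled graph's origin shape and dtype but the current batch's device pointers. A hedged sketch of that wrapper pattern with stand-in types; the real MLUTensor API is only what the diff shows (get_origin_shape, dtype, set_mlu_dtype, set_mlu_ptr), and TensorStub/WrapForRun are illustrative:

#include <cstdint>
#include <memory>
#include <vector>

// Stand-in for MLUTensor with just the fields the diff touches.
struct TensorStub {
  std::vector<int64_t> shape;
  int dtype = 0;
  void* ptr = nullptr;
};

// Per-invocation wrapper construction, mirroring the loops above:
// shape and dtype come from the graph's recorded tensor, the data
// pointer from this run's origin tensor.
std::shared_ptr<TensorStub> WrapForRun(const TensorStub& graph_tensor,
                                       void* run_ptr) {
  auto t = std::make_shared<TensorStub>();
  t->shape = graph_tensor.shape;  // origin shape recorded at build time
  t->dtype = graph_tensor.dtype;  // dtype recorded at build time
  t->ptr = run_ptr;               // fresh device pointer for this run
  return t;
}

As a design note, the patch builds tmp by value and then copies it into make_shared; constructing the shared tensor directly would save a copy, though the behavior is the same.
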