未验证 提交 e40d56c3 编写于 作者: F flame 提交者: GitHub

anakin subgraph engine (#15774)

* add anakin subgraph engine

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* add initial op converter

* update

* update

* fix op register compile error

* update
test=develop

* update
上级 613d9d07
......@@ -16,6 +16,7 @@ add_subdirectory(utils)
if (TENSORRT_FOUND)
add_subdirectory(tensorrt)
endif()
# add_subdirectory(anakin)
get_property(fluid_modules GLOBAL PROPERTY FLUID_MODULES)
get_property(cuda_modules GLOBAL PROPERTY CUDA_MODULES)
......
cc_library(anakin_engine SRCS engine.cc)
target_link_libraries(anakin_engine anakin anakin_saber_common)
cc_test(test_anakin_engine SRCS test_anakin_engine.cc DEPS anakin_engine)
add_subdirectory(convert)
cc_library(anakin_op_converter SRCS fc.cc registrar.cc DEPS anakin_engine framework_proto scope)
cc_test(test_anakin_fc SRCS test_fc_op.cc DEPS anakin_op_converter mul_op)
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/inference/anakin/convert/fc.h"
namespace paddle {
namespace inference {
namespace anakin {
void FcOpConverter::operator()(const framework::proto::OpDesc &op,
const framework::Scope &scope, bool test_mode) {
framework::OpDesc op_desc(op, nullptr);
PADDLE_ENFORCE_EQ(op_desc.Input("X").size(), 1);
PADDLE_ENFORCE_EQ(op_desc.Input("Y").size(), 1);
PADDLE_ENFORCE_EQ(op_desc.Input("Out").size(), 1);
auto x_name = op_desc.Input("X").front();
PADDLE_ENFORCE(x_name.size() > 0);
auto *y_v = scope.FindVar(op_desc.Input("Y").front());
PADDLE_ENFORCE_NOT_NULL(y_v);
auto *y_t = y_v->GetMutable<framework::LoDTensor>();
auto shape = framework::vectorize2int(y_t->dims());
}
} // namespace anakin
} // namespace inference
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/inference/anakin/convert/op_converter.h"
namespace paddle {
namespace inference {
namespace anakin {
class FcOpConverter : public AnakinOpConverter {
public:
FcOpConverter() = default;
virtual void operator()(const framework::proto::OpDesc &op,
const framework::Scope &scope,
bool test_mode) override;
virtual ~FcOpConverter() {}
private:
};
static Registrar<FcOpConverter> register_fc_op_converter("fc");
} // namespace anakin
} // namespace inference
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include "framework/core/types.h"
#include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/inference/anakin/convert/registrar.h"
#include "paddle/fluid/inference/anakin/engine.h"
#include "paddle/fluid/inference/utils/singleton.h"
#include "saber/saber_types.h"
namespace paddle {
namespace inference {
namespace anakin {
using AnakinNvEngine =
AnakinEngine<::anakin::saber::NV, ::anakin::Precision::FP32>;
class AnakinOpConverter {
public:
AnakinOpConverter() = default;
virtual void operator()(const framework::proto::OpDesc &op,
const framework::Scope &scope, bool test_mode) {}
void ConvertOp(const framework::proto::OpDesc &op,
const std::unordered_set<std::string> &parameters,
const framework::Scope &scope, AnakinNvEngine *engine,
bool test_mode = false) {
framework::OpDesc op_desc(op, nullptr);
std::string op_type = op_desc.Type();
std::shared_ptr<AnakinOpConverter> it{nullptr};
if (op_type == "mul") {
PADDLE_ENFORCE_EQ(op_desc.Input("Y").size(), 1UL);
std::string Y = op_desc.Input("Y")[0];
std::cout << Y << parameters.count(Y) << std::endl;
if (parameters.count(Y)) {
it = OpRegister::instance()->Get("fc");
}
}
if (!it) {
it = OpRegister::instance()->Get(op_type);
}
PADDLE_ENFORCE_NOT_NULL(it, "no OpConverter for optype [%s]", op_type);
it->SetEngine(engine);
(*it)(op, scope, test_mode);
}
void ConvertBlock(const framework::proto::BlockDesc &block,
const std::unordered_set<std::string> &parameters,
const framework::Scope &scope, AnakinNvEngine *engine) {
std::unique_lock<std::mutex> lock(mutex_);
for (auto i = 0; i < block.ops_size(); i++) {
auto &op = block.ops(i);
ConvertOp(op, parameters, scope, engine);
}
}
void SetEngine(AnakinNvEngine *engine) { engine_ = engine; }
virtual ~AnakinOpConverter() {}
protected:
bool test_mode_;
AnakinNvEngine *engine_{nullptr};
private:
std::unordered_map<std::string, AnakinOpConverter *> converters_;
framework::Scope *scope_{nullptr};
std::mutex mutex_;
};
} // namespace anakin
} // namespace inference
} // namespace paddle
#define REGISTER_ANAKIN_OP_CONVERTER(op_type__, Converter__) \
struct anakin_##op_type__##_converter \
: public ::paddle::framework::Registrar { \
anakin_##op_type__##_converter() { \
::paddle::inference:: \
Registry<paddle::inference::anakin::AnakinOpConverter>::Register< \
::paddle::inference::anakin::Converter__>(#op_type__); \
} \
}; \
anakin_##op_type__##_converter anakin_##op_type__##_converter__; \
int TouchConverterRegister_anakin_##op_type__() { \
anakin_##op_type__##_converter__.Touch(); \
return 0; \
}
#define USE_ANAKIN_CONVERTER(op_type__) \
extern int TouchConverterRegister_anakin_##op_type__(); \
static int use_op_converter_anakin_##op_type__ __attribute__((unused)) = \
TouchConverterRegister_anakin_##op_type__();
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/inference/anakin/convert/registrar.h"
namespace paddle {
namespace inference {
namespace anakin {
std::shared_ptr<AnakinOpConverter> OpRegister::Get(const std::string &name) {
auto it = registry_.find(name);
if (it == registry_.end()) return nullptr;
return it->second();
}
OpRegister *OpRegister::instance() {
static OpRegister factory;
return &factory;
}
} // namespace anakin
} // namespace inference
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <functional>
#include <map>
#include <memory>
#include <string>
#include <utility>
namespace paddle {
namespace inference {
namespace anakin {
class AnakinOpConverter;
class OpRegister {
public:
OpRegister() = default;
std::shared_ptr<AnakinOpConverter> Get(const std::string &name);
static OpRegister *instance();
void OpRegisterFn(const std::string &name,
std::function<std::shared_ptr<AnakinOpConverter>()> fn) {
registry_[name] = fn;
}
private:
using RegisterFnType = std::function<std::shared_ptr<AnakinOpConverter>()>;
std::map<std::string, std::function<std::shared_ptr<AnakinOpConverter>()>>
registry_;
};
template <typename T, typename... Args>
class Registrar {
public:
Registrar(const std::string &name, Args... args) {
std::shared_ptr<AnakinOpConverter> converter =
std::make_shared<T>(std::move(args)...);
OpRegister::instance()->OpRegisterFn(name,
[converter]() { return converter; });
}
};
} // namespace anakin
} // namespace inference
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <gtest/gtest.h>
#include "paddle/fluid/inference/anakin/convert/fc.h"
#include "paddle/fluid/inference/anakin/convert/op_converter.h"
#include "paddle/fluid/inference/anakin/convert/ut_helper.h"
namespace paddle {
namespace inference {
namespace anakin {
TEST(fc_op, test) {
auto it = OpRegister::instance()->Get("fc");
ASSERT_TRUE(it != nullptr);
std::unordered_set<std::string> parameters({"mul_y"});
framework::Scope scope;
AnakinConvertValidation validator(parameters, scope);
validator.DeclInputVar("mul_x", {1, 1, 1, 1});
validator.DeclParamVar("mul_y", {1, 1, 1, 2});
validator.DeclOutputVar("mul_out", {1, 1, 1, 2});
// Prepare Op description
framework::OpDesc desc;
desc.SetType("mul");
desc.SetInput("X", {"mul_x"});
desc.SetInput("Y", {"mul_y"});
desc.SetOutput("Out", {"mul_out"});
int num_flatten_dims = 3;
desc.SetAttr("x_num_col_dims", num_flatten_dims);
validator.SetOp(*desc.Proto());
validator.Execute(10);
}
} // namespace anakin
} // namespace inference
} // namespace paddle
USE_OP(mul);
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/inference/anakin/engine.h"
#include "paddle/fluid/inference/analysis/helper.h"
#include "paddle/fluid/inference/utils/singleton.h"
#include "paddle/fluid/platform/enforce.h"
using anakin::graph::GraphGlobalMem;
using anakin::AK_FLOAT;
using anakin::Precision;
using anakin::saber::NV;
using anakin::saber::X86;
using anakin::saber::Shape;
using anakin::PBlock;
using anakin::PTuple;
namespace paddle {
namespace inference {
namespace anakin {
/*
* Get a random float value between [low, high]
*/
float random(float low, float high) {
static std::random_device rd;
static std::mt19937 mt(rd());
std::uniform_real_distribution<double> dist(low, high);
return dist(mt);
}
void RandomizeTensor(framework::LoDTensor* tensor, const platform::Place& place,
const platform::DeviceContext& ctx) {
auto dims = tensor->dims();
size_t num_elements = analysis::AccuDims(dims, dims.size());
PADDLE_ENFORCE_GT(num_elements, 0);
platform::CPUPlace cpu_place;
framework::LoDTensor temp_tensor;
temp_tensor.Resize(dims);
auto* temp_data = temp_tensor.mutable_data<float>(cpu_place);
for (size_t i = 0; i < num_elements; i++) {
*(temp_data + i) = random(0., 1.);
}
TensorCopySync(temp_tensor, place, tensor);
}
/*
* Help to validate the correctness between Fluid Op and the corresponding
* anakin
* layer.
*/
class AnakinConvertValidation {
using AnakinNvEngineT = AnakinEngine<NV, Precision::FP32>;
public:
AnakinConvertValidation() = delete;
AnakinConvertValidation(const std::unordered_set<std::string>& parameters,
const framework::Scope& scope)
: parameters_(parameters), scope_(scope), place_(0) {
PADDLE_ENFORCE_EQ(cudaStreamCreate(&stream_), 0);
engine_.reset(new AnakinEngine<NV, Precision::FP32>(true));
}
// Declare a Variable as input with random initialization.
void DeclInputVar(const std::string& name,
const std::vector<int> tensor_dims) {
DeclVar(name, tensor_dims);
// should decalre anakin input here.
}
void DeclParamVar(const std::string& name, const std::vector<int> dim_vec) {
DeclVar(name, dim_vec);
}
void DeclOutputVar(const std::string& name, const std::vector<int> dim_vec) {
DeclVar(name, dim_vec);
// should declare anakin output here.
}
void DeclVar(const std::string& name, const std::vector<int> dim_vec) {
platform::CUDADeviceContext ctx(place_);
auto* x = scope_.Var(name);
auto* x_tensor = x->GetMutable<framework::LoDTensor>();
x_tensor->Resize(framework::make_ddim(dim_vec));
RandomizeTensor(x_tensor, place_, ctx);
}
void SetOp(const framework::proto::OpDesc& desc) {
op_ = framework::OpRegistry::CreateOp(desc);
op_desc_.reset(new framework::OpDesc(desc, nullptr));
// should init anakin engine here.
Singleton<AnakinOpConverter>::Global().ConvertOp(
desc, parameters_, scope_, engine_.get(), true /*test_mode*/);
engine_->Freeze();
for (const auto& input : op_desc_->InputArgumentNames()) {
if (parameters_.count(input)) continue;
auto& t = inference::analysis::GetFromScope<framework::LoDTensor>(scope_,
input);
auto t_shape = framework::vectorize2int(t.dims());
engine_->SetInputShape(input, t_shape);
}
engine_->Optimize();
}
// We use the set 'neglected_output' here, because some Ops like batch norm,
// the outputs specified in the op des are only used during training,
// so we should neglect those output during inference.
void Execute(int batch_size,
std::unordered_set<std::string> neglected_output = {}) {
// Execute Fluid Op
platform::CUDADeviceContext ctx(place_);
op_->Run(scope_, place_);
for (const auto& output : op_desc_->OutputArgumentNames()) {
if (neglected_output.count(output)) continue;
std::vector<float> fluid_out;
auto* var = scope_.FindVar(output);
auto* tensor = var->GetMutable<framework::LoDTensor>();
framework::TensorToVector(*tensor, ctx, &fluid_out);
size_t fluid_out_size = fluid_out.size();
for (size_t i = 0; i < fluid_out_size; i++) {
std::cout << fluid_out[i] << std::endl;
}
}
}
framework::Scope& scope() { return scope_; }
private:
std::unique_ptr<AnakinNvEngineT> engine_{nullptr};
cudaStream_t stream_;
std::unique_ptr<framework::OperatorBase> op_;
std::unique_ptr<framework::OpDesc> op_desc_;
const std::unordered_set<std::string>& parameters_;
framework::Scope& scope_;
platform::CUDAPlace place_;
};
} // namespace anakin
} // namespace inference
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/inference/anakin/engine.h"
#include <algorithm>
#include <cstring>
#include <map>
#include <utility>
#include "paddle/fluid/framework/ddim.h"
using anakin::Precision;
using anakin::OpRunType;
using paddle::framework::LoDTensor;
template <typename T, Precision P, OpRunType O>
using AnakinNetT = anakin::Net<T, P, O>;
template <typename T, Precision P>
using AnakinGraphT = anakin::graph::Graph<T, P>;
namespace paddle {
namespace inference {
namespace anakin {
template <typename TargetT, Precision PrecisionType, OpRunType RunType>
AnakinEngine<TargetT, PrecisionType, RunType>::AnakinEngine(bool need_summary)
: graph_(new AnakinGraphT<TargetT, PrecisionType>()),
net_(new AnakinNetT<TargetT, PrecisionType, RunType>(need_summary)) {}
template <typename TargetT, Precision PrecisionType, OpRunType RunType>
AnakinEngine<TargetT, PrecisionType, RunType>::~AnakinEngine() {}
template <typename TargetT, Precision PrecisionType, OpRunType RunType>
void AnakinEngine<TargetT, PrecisionType, RunType>::SetInputShape(
const std::string &name, std::vector<int> shape) {
graph_->AddOpAttr<::anakin::PTuple<int>>(name, "input_shape",
std::move(shape));
}
template <typename TargetT, Precision PrecisionType, OpRunType RunType>
void AnakinEngine<TargetT, PrecisionType, RunType>::InitGraph() {
net_->init(*graph_);
}
template <typename TargetT, Precision PrecisionType, OpRunType RunType>
void AnakinEngine<TargetT, PrecisionType, RunType>::AddOp(
const std::string &name, const std::string &type,
const std::vector<std::string> &inputs,
const std::vector<std::string> &outputs) {
PADDLE_ENFORCE(graph_->AddOp(name, type, inputs, outputs), "Add operation.");
}
template <typename TargetT, Precision PrecisionType, OpRunType RunType>
void AnakinEngine<TargetT, PrecisionType, RunType>::Execute(
const std::map<std::string, framework::LoDTensor *> &inputs,
const std::map<std::string, framework::LoDTensor *> &outputs) {
for (const auto &input : inputs) {
auto *tensor = input.second;
auto *data = tensor->data<float>();
auto shape = framework::vectorize2int(tensor->dims());
::anakin::saber::Shape anakin_shape(shape);
auto *anakin_input = net_->get_in(input.first);
::anakin::saber::Tensor<TargetT> tmp_anakin_tensor(data, TargetT(), 0,
anakin_shape);
anakin_input->share_from(tmp_anakin_tensor);
}
for (const auto &output : outputs) {
auto *tensor = output.second;
auto *data = tensor->data<float>();
auto shape = framework::vectorize2int(tensor->dims());
::anakin::saber::Shape anakin_shape(shape);
auto *anakin_output = net_->get_out(output.first);
::anakin::saber::Tensor<TargetT> tmp_anakin_tensor(data, TargetT(), 0,
anakin_shape);
anakin_output->share_from(tmp_anakin_tensor);
}
net_->prediction();
}
template <typename TargetT, Precision PrecisionType, OpRunType RunType>
void AnakinEngine<TargetT, PrecisionType, RunType>::Freeze() {
PADDLE_ENFORCE(graph_->Freeze(), "Freeze anakin subgraph.");
}
template <typename TargetT, Precision PrecisionType, OpRunType RunType>
void AnakinEngine<TargetT, PrecisionType, RunType>::Optimize() {
PADDLE_ENFORCE(graph_->Optimize(), "Graph optimization.");
}
template <typename TargetT, Precision PrecisionType, OpRunType RunType>
std::unique_ptr<AnakinEngine<TargetT, PrecisionType, RunType>>
AnakinEngine<TargetT, PrecisionType, RunType>::Clone() {
auto *engine = new AnakinEngine();
engine->net_ = std::move(net_->Clone());
return std::unique_ptr<AnakinEngine>(engine);
}
template class AnakinEngine<::anakin::saber::NV, ::anakin::Precision::FP32>;
} // namespace anakin
} // namespace inference
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include <map>
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/inference/engine.h"
#include "paddle/fluid/inference/utils/singleton.h"
#include "framework/core/net/net.h"
#include "framework/core/types.h"
#include "framework/graph/graph.h"
#include "saber/saber_types.h"
namespace anakin {
template <typename, Precision, OpRunType>
class Net;
namespace graph {
template <typename, Precision>
class Graph;
} // namespace graph
} // namespace anakin
namespace paddle {
namespace inference {
namespace anakin {
template <typename TargetT, ::anakin::Precision PrecisionType,
::anakin::OpRunType RunType = ::anakin::OpRunType::ASYNC>
class AnakinEngine {
public:
explicit AnakinEngine(bool need_summary = false);
~AnakinEngine();
void InitGraph();
void SetInputShape(const std::string &name, std::vector<int> shape);
void AddOp(const std::string &name, const std::string &type,
const std::vector<std::string> &inputs,
const std::vector<std::string> &outputs);
template <typename T>
void AddOpAttr(const std::string &op_name, const std::string &attr_name,
const T &attr_value) {
PADDLE_ENFORCE(graph_->AddOpAttr(op_name, attr_name, attr_value),
"Add operation's attribution.");
}
std::unique_ptr<AnakinEngine> Clone();
void Freeze();
void Optimize();
void Execute(const std::map<std::string, framework::LoDTensor *> &inputs,
const std::map<std::string, framework::LoDTensor *> &outputs);
private:
using NetT = ::anakin::Net<TargetT, PrecisionType, RunType>;
using GraphT = ::anakin::graph::Graph<TargetT, PrecisionType>;
std::unique_ptr<GraphT> graph_;
std::unique_ptr<NetT> net_;
};
} // namespace anakin
} // namespace inference
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <glog/logging.h>
#include <gtest/gtest.h>
#include <map>
#include "framework/core/net/net.h"
#include "framework/graph/graph.h"
#include "framework/graph/graph_global_mem.h"
#include "paddle/fluid/inference/anakin/engine.h"
using anakin::graph::GraphGlobalMem;
using anakin::AK_FLOAT;
using anakin::Precision;
using anakin::saber::NV;
using anakin::saber::X86;
using anakin::saber::Shape;
using anakin::PBlock;
using anakin::PTuple;
namespace paddle {
namespace inference {
namespace anakin {
class TestAnakinEngine : public ::testing::Test {
protected:
void SetUp() override;
void TearDown() override {}
protected:
using AnakinNvEngineT = AnakinEngine<NV, Precision::FP32>;
std::unique_ptr<AnakinNvEngineT> engine_{nullptr};
};
void TestAnakinEngine::SetUp() {
engine_.reset(new AnakinEngine<NV, Precision::FP32>(true));
TEST_F(TestAnakinEngine, Execute) {
engine_->AddOp("op1", "Dense", {"x"}, {"y"});
engine_->AddOpAttr("op1", "out_dim", 2);
engine_->AddOpAttr("op1", "bias_term", false);
engine_->AddOpAttr("op1", "axis", 1);
std::vector<int> shape = {1, 1, 1, 2};
Shape tmp_shape(shape);
auto *weight1 =
GraphGlobalMem<NV>::Global().template new_block<AK_FLOAT>(tmp_shape);
float *cpu_data = static_cast<float *>(weight1->h_tensor().mutable_data());
cpu_data[0] = 2.;
weight1->d_tensor().set_shape(tmp_shape);
weight1->d_tensor().copy_from(weight1->h_tensor());
engine_->AddOpAttr("op1", "weight_1", *weight1);
engine_->Freeze();
engine_->SetInputShape("x", {1, 1, 1, 1});
engine_->Optimize();
engine_->InitGraph();
framework::LoDTensor x;
framework::LoDTensor y;
x.Resize({1, 1, 1, 1});
y.Resize({1, 1, 1, 2});
auto *x_data = x.mutable_data<float>(platform::CUDAPlace());
float x_data_cpu[] = {1.};
cudaMemcpy(x_data, x_data_cpu, sizeof(float), cudaMemcpyHostToDevice);
std::map<std::string, framework::LoDTensor *> inputs = {{"x", &x}};
auto *y_data = y.mutable_data<float>(platform::CUDAPlace());
std::map<std::string, framework::LoDTensor *> outputs = {{"y", &y}};
engine_->Execute(inputs, outputs);
auto *y_data_gpu = y_data;
float y_data_cpu[2];
cudaMemcpy(y_data_cpu, y_data_gpu, sizeof(float) * 2,
cudaMemcpyDeviceToHost);
LOG(INFO) << "output value: " << y_data_cpu[0] << ", " << y_data_cpu[1];
}
}
} // namespace anakin
} // namespace inference
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册