Unverified commit a0bccd9e, authored by WangZhen, committed by GitHub

[JitLayer]Pybind PEFunction and call phi api in layer_test (#44465)

* Support predictor function in JitLayer

* Pybind PEFunction

* Pybind PEFunction and call phi api in layer_test

* Call sqrt phi API

* Polish flags

* Fix comments
Parent: 5b3f91df
@@ -37,6 +37,7 @@ if(WITH_TESTING AND NOT WIN32)
     COMMAND tar zxf multi_program_load.tar.gz)
   set(JIT_DEPS
       phi
+      phi_api
       elementwise_add_op
       matmul_v2_op
       activation_op
...
@@ -38,6 +38,7 @@ class ExecutorFunction : public BaseFunction {
       : info_(info), place_(place), inner_exe_(place_) {
     utils::ShareParamsIntoScope(info_->ParamNames(), params_dict, &scope_);
     VLOG(6) << framework::GenScopeTreeDebugInfo(&scope_);
+    info_->RemoveDescFeedFetch();
   }

   ~ExecutorFunction() noexcept {}
...
@@ -62,8 +62,6 @@ FunctionInfo::FunctionInfo(const std::string& func_name,
   for (auto& out_name : program_desc_.GetFetchTargetNames()) {
     schema_.AddOutputArg(out_name);
   }
-  // remove feed fetch op
-  utils::RemoveFeedFetch(&program_desc_);
 }

 const std::string& FunctionInfo::FunctionName() const { return func_name_; }
@@ -84,5 +82,9 @@ const std::vector<std::string> FunctionInfo::OutputArgNames() const {
   return schema_.OutputArgNames();
 }

+void FunctionInfo::RemoveDescFeedFetch() {
+  utils::RemoveFeedFetch(&program_desc_);
+}
+
 }  // namespace jit
 }  // namespace paddle
@@ -70,6 +70,8 @@ class FunctionInfo {
   const std::vector<std::string> OutputArgNames() const;

+  void RemoveDescFeedFetch();
+
  private:
   std::string func_name_;
   std::vector<std::string> param_names_;
...
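Viewed together, the three hunks above move feed/fetch stripping out of the FunctionInfo constructor and behind an explicit RemoveDescFeedFetch() method, so a ProgramDesc keeps its feed/fetch ops until a concrete function wrapper claims it. A hedged sketch of the resulting call order; the types are the ones from the hunks above, and only the BuildFunction wrapper itself is hypothetical:

    // The function constructor shares params into its scope, then strips
    // the feed/fetch ops by calling info->RemoveDescFeedFetch() itself.
    void BuildFunction(const std::shared_ptr<FunctionInfo> &info,
                       const Name2VariableMap &params_dict,
                       const phi::Place &place) {
      ExecutorFunction func(info, params_dict, place);
    }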
@@ -12,17 +12,20 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.

+#include <cmath>
 #include <string>

 #include "gtest/gtest.h"

 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/framework/variable.h"
+#include "paddle/phi/api/include/api.h"
 #include "paddle/phi/core/dense_tensor.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/core/tensor_utils.h"
 #include "paddle/phi/kernels/funcs/math_function.h"

+#include "paddle/fluid/jit/function_utils.h"
 #include "paddle/fluid/jit/layer.h"
 #include "paddle/fluid/jit/serializer.h"
@@ -52,7 +55,7 @@ namespace paddle {
 namespace jit {
 using DenseTensor = phi::DenseTensor;

-std::vector<DenseTensor> PrepareInputs(const phi::Place& place) {
+std::vector<Tensor> PrepareInputs(const phi::Place& place) {
   platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
   auto& dev_ctx = *pool.Get(place);
@@ -61,7 +64,7 @@ std::vector<DenseTensor> PrepareInputs(const phi::Place& place) {
   t.mutable_data<float>(place);
   phi::funcs::set_constant(dev_ctx, &t, 2.);

-  return {t};
+  return utils::ToTensors({t});
 }

 TEST(CpuLayerTest, Construct) {
@@ -78,34 +81,38 @@ TEST(CpuLayerTest, Construct) {
   outs = (*func)(inputs);
   out_data = outs[0].data<float>();
   EXPECT_NEAR(out_data[0], 1.41562390, 1e-6);

+  auto pow_out =
+      paddle::experimental::pow(outs[0], paddle::experimental::Scalar(2));
+  out_data = pow_out.data<float>();
+  EXPECT_NEAR(out_data[0], pow(1.41562390, 2.0), 1e-6);
 }

 #if defined(PADDLE_WITH_CUDA)
 TEST(GpuLayerTest, Construct) {
   auto place = phi::GPUPlace();
-  platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
-  auto& dev_ctx = *pool.Get(place);
-  const auto* dev_ctx_gpu = static_cast<const phi::GPUContext*>(&dev_ctx);
-  DenseTensor cpu_dense_tensor;

   std::string path = "./multi_program_load/export";
   auto layer = jit::Load(path, place);
   auto inputs = PrepareInputs(place);

   auto outs = layer.forward(inputs);
-  auto out_dense_tensor = outs[0];
-  phi::Copy(
-      *dev_ctx_gpu, out_dense_tensor, phi::CPUPlace(), true, &cpu_dense_tensor);
-  auto out_data = cpu_dense_tensor.data<float>();
+  auto gpu_tensor = outs[0];
+  auto cpu_tensor =
+      paddle::experimental::copy_to(gpu_tensor, phi::CPUPlace(), true);
+  auto out_data = cpu_tensor.data<float>();
   EXPECT_NEAR(out_data[0], 0.02194316, 1e-6);

   auto func = layer.Function("infer");
   outs = (*func)(inputs);
-  out_dense_tensor = outs[0];
-  phi::Copy(
-      *dev_ctx_gpu, out_dense_tensor, phi::CPUPlace(), true, &cpu_dense_tensor);
-  out_data = cpu_dense_tensor.data<float>();
+  gpu_tensor = outs[0];
+  cpu_tensor = paddle::experimental::copy_to(gpu_tensor, phi::CPUPlace(), true);
+  out_data = cpu_tensor.data<float>();
   EXPECT_NEAR(out_data[0], 1.41562390, 1e-6);

+  auto sqrt_out = paddle::experimental::sqrt(outs[0]);
+  cpu_tensor = paddle::experimental::copy_to(sqrt_out, phi::CPUPlace(), true);
+  out_data = cpu_tensor.data<float>();
+  EXPECT_NEAR(out_data[0], sqrt(1.41562390), 1e-6);
 }
 #endif
...
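For readers new to the phi C++ API exercised above: pow, sqrt, and copy_to all live in paddle::experimental and operate on paddle::experimental::Tensor. A minimal hedged sketch of the same pattern, assuming a Paddle build linked against phi_api; the helper function itself is illustrative, not part of the commit:

    #include "paddle/phi/api/include/api.h"

    // Square x, take the element-wise square root, then bring the result
    // back to the host; copy_to blocks until the device copy finishes
    // when its last argument is true.
    paddle::experimental::Tensor SqrtOfSquareOnHost(
        const paddle::experimental::Tensor &x) {
      auto squared =
          paddle::experimental::pow(x, paddle::experimental::Scalar(2));
      auto root = paddle::experimental::sqrt(squared);
      return paddle::experimental::copy_to(root, phi::CPUPlace(),
                                           /*blocking=*/true);
    }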
@@ -39,6 +39,7 @@ class PEFunction : public BaseFunction {
       : info_(info), place_(place) {
     utils::ShareParamsIntoScope(info_->ParamNames(), params_dict, &scope_);
     VLOG(6) << framework::GenScopeTreeDebugInfo(&scope_);
+    info_->RemoveDescFeedFetch();
   }

   ~PEFunction() noexcept {}
@@ -51,13 +52,14 @@ class PEFunction : public BaseFunction {
   std::vector<DenseTensor> operator()(const std::vector<DenseTensor> &inputs) {
     std::string prog_string;
     std::hash<std::string> string_hash;
+
     auto &program_desc = info_->ProgramDesc();
     // TODO(dev): Serialize is very slow.
     const_cast<framework::ProgramDesc *>(&program_desc)
         ->Proto()
         ->SerializePartialToString(&prog_string);
     int64_t program_id = static_cast<int64_t>(string_hash(prog_string));

     const framework::BlockDesc &global_block = program_desc.Block(0);
     int64_t start_op_index = 0;
     int64_t end_op_index = static_cast<int64_t>(global_block.OpSize());
@@ -97,6 +99,8 @@ class PEFunction : public BaseFunction {
     return res;
   }

+  const std::shared_ptr<FunctionInfo> &Info() const { return info_; }
+
  private:
   std::shared_ptr<FunctionInfo> info_;
   framework::Scope scope_;
...
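The program_id computed in operator() above is just a process-local cache key: the serialized program proto run through std::hash. The same computation in isolation; note that std::hash<std::string> is only stable within one process, which is all a run-time cache needs:

    #include <functional>
    #include <string>

    // Process-local identity for a serialized ProgramDesc proto.
    int64_t ProgramId(const std::string &serialized_proto) {
      return static_cast<int64_t>(std::hash<std::string>{}(serialized_proto));
    }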
@@ -19,8 +19,11 @@
 #include "paddle/fluid/platform/device_context.h"

 #include "paddle/fluid/jit/executor_function.h"
+#include "paddle/fluid/jit/pe_function.h"
 #include "paddle/fluid/jit/serializer_utils.h"

+DECLARE_string(jit_engine_type);
+
 namespace paddle {
 namespace jit {
@@ -55,9 +58,19 @@ Layer Deserializer::operator()(const std::string& path,
   Layer layer = Layer(infos, params_dict, place);

   for (auto& info : infos) {
-    layer.SetFunction(
-        info->FunctionName(),
-        utils::MakeFunction<ExecutorFunction>(info, params_dict, place));
+    if (FLAGS_jit_engine_type == "Executor") {
+      VLOG(3) << "Add function type: ExecutorFunction.";
+      layer.SetFunction(
+          info->FunctionName(),
+          utils::MakeFunction<ExecutorFunction>(info, params_dict, place));
+    } else if (FLAGS_jit_engine_type == "PE") {
+      VLOG(3) << "Add function type: PEFunction.";
+      layer.SetFunction(
+          info->FunctionName(),
+          utils::MakeFunction<PEFunction>(info, params_dict, place));
+    } else {
+      PD_THROW("Invalid JitLayer function type.");
+    }
   }

   return layer;
@@ -85,7 +98,7 @@ void Deserializer::ReadAttributeData(const std::string& file_path,
                                      Name2VariableMap* attrs_dict) const {}

 framework::ProgramDesc Deserializer::LoadProgram(const std::string& file_name) {
-  VLOG(3) << "LoadProgram " << file_name;
+  VLOG(3) << "LoadProgram from: " << file_name;
   std::ifstream fin(file_name, std::ios::in | std::ios::binary);
   fin.seekg(0, std::ios::end);
   std::string buffer(fin.tellg(), ' ');
...
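The LoadProgram hunk is cut off right after the buffer is sized; the usual completion of this read-whole-file idiom looks like the sketch below. The two trailing lines (rewind and read) are an assumption, since the diff does not show them:

    #include <fstream>
    #include <string>

    // Size the buffer from the end-of-file position, rewind, read it all.
    std::string ReadFileToString(const std::string &file_name) {
      std::ifstream fin(file_name, std::ios::in | std::ios::binary);
      fin.seekg(0, std::ios::end);
      std::string buffer(fin.tellg(), ' ');
      fin.seekg(0, std::ios::beg);
      fin.read(&buffer[0], buffer.size());
      return buffer;
    }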
@@ -916,3 +916,18 @@ PADDLE_DEFINE_EXPORTED_bool(
     einsum_opt,
     false,
     "EinsumOp backward will be speedup at the expense of more gpu memory.");
+
+/**
+ * JitLayer related FLAG
+ * Name: FLAGS_jit_engine_type
+ * Since Version: 2.3.0
+ * Value Range: string, {Executor, PE},
+ * default=PE
+ * Example:
+ * Note:
+ * FLAGS_jit_engine_type == Executor, JitLayer uses ExecutorFunction
+ * FLAGS_jit_engine_type == PE, JitLayer uses PEFunction
+ */
+PADDLE_DEFINE_EXPORTED_string(jit_engine_type,
+                              "PE",
+                              "Choose default function type in JitLayer.");
@@ -21,6 +21,8 @@ limitations under the License. */
 #include "paddle/fluid/framework/convert_utils.h"
 #include "paddle/fluid/framework/scope.h"
 #include "paddle/fluid/framework/scope_guard.h"
+#include "paddle/fluid/jit/executor_function.h"
+#include "paddle/fluid/jit/pe_function.h"
 #include "paddle/fluid/memory/allocation/allocator.h"
 #include "paddle/fluid/operators/py_func_op.h"
 #include "paddle/fluid/operators/utils.h"
@@ -52,6 +54,7 @@ extern PyTypeObject* g_framework_tensor_pytype;
 extern PyTypeObject* g_framework_lodtensorarray_pytype;
 extern PyTypeObject* g_custom_op_kernel_ctx_pytype;
 extern PyTypeObject* g_executor_function_pytype;
+extern PyTypeObject* g_pe_function_pytype;

 int TensorDtype2NumpyDtype(phi::DataType dtype) {
   switch (dtype) {
@@ -234,6 +237,9 @@ std::shared_ptr<jit::BaseFunction> CastPyArg2BaseFunction(PyObject* obj,
           obj, reinterpret_cast<PyObject*>(g_executor_function_pytype))) {
     return ::pybind11::handle(obj)
         .cast<std::shared_ptr<jit::ExecutorFunction>>();
+  } else if (PyObject_IsInstance(
+                 obj, reinterpret_cast<PyObject*>(g_pe_function_pytype))) {
+    return ::pybind11::handle(obj).cast<std::shared_ptr<jit::PEFunction>>();
   } else {
     PADDLE_THROW(platform::errors::InvalidArgument(
         "argument (position %d) must be "
...
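The new branch follows the standard pybind11 recipe: a PyObject_IsInstance check against the cached PyTypeObject*, then handle::cast to the shared_ptr holder registered for the class. Reduced to a generic, hedged template, where T stands for any class bound with a shared_ptr holder:

    #include <memory>

    #include <pybind11/pybind11.h>

    // Borrow obj (no refcount change) and cast it to the shared_ptr holder
    // registered for T; throws pybind11::cast_error on a mismatch.
    template <typename T>
    std::shared_ptr<T> CastToSharedPtr(PyObject *obj) {
      return ::pybind11::handle(obj).cast<std::shared_ptr<T>>();
    }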
@@ -19,7 +19,7 @@ typedef SSIZE_T ssize_t;
 #include "paddle/fluid/framework/lod_tensor.h"
 #include "paddle/fluid/framework/tensor.h"
-#include "paddle/fluid/jit/executor_function.h"
+#include "paddle/fluid/jit/base_function.h"
 #include "paddle/fluid/platform/place.h"
 #include "paddle/phi/common/backend.h"
 #include "paddle/phi/common/data_type.h"
...
@@ -21,6 +21,7 @@ limitations under the License. */
 #include "paddle/fluid/jit/executor_function.h"
 #include "paddle/fluid/jit/function_schema.h"
 #include "paddle/fluid/jit/layer.h"
+#include "paddle/fluid/jit/pe_function.h"
 #include "paddle/fluid/jit/serializer.h"

 namespace py = pybind11;
@@ -29,6 +30,7 @@ namespace paddle {
 namespace pybind {

 PyTypeObject *g_executor_function_pytype = nullptr;
+PyTypeObject *g_pe_function_pytype = nullptr;
 using Variable = paddle::framework::Variable;

 void BindJit(pybind11::module *m) {
@@ -44,6 +46,11 @@ void BindJit(pybind11::module *m) {
       reinterpret_cast<PyTypeObject *>(executor_function.ptr());
   executor_function.def("info", &jit::ExecutorFunction::Info);

+  py::class_<jit::PEFunction, std::shared_ptr<jit::PEFunction>> pe_function(
+      *m, "PEFunction", R"DOC(PEFunction Class.)DOC");
+  g_pe_function_pytype = reinterpret_cast<PyTypeObject *>(pe_function.ptr());
+  pe_function.def("info", &jit::PEFunction::Info);
+
   py::class_<jit::FunctionInfo, std::shared_ptr<jit::FunctionInfo>>(
       *m, "FunctionInfo", R"DOC(FunctionInfo Class.)DOC")
       .def("name", &jit::FunctionInfo::FunctionName)
...
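For reference, the binding added above is pybind11's holder-type pattern: register the class with a std::shared_ptr holder so ownership is shared across the language boundary, and cache the PyTypeObject* for the isinstance dispatch seen earlier. A self-contained sketch with a hypothetical MyFunc standing in for jit::PEFunction:

    #include <memory>
    #include <string>

    #include <pybind11/pybind11.h>

    namespace py = pybind11;

    struct MyFunc {  // hypothetical stand-in for jit::PEFunction
      std::string Info() const { return "MyFunc"; }
    };

    PyTypeObject *g_my_func_pytype = nullptr;

    // Bind with a shared_ptr holder and keep the type object so a later
    // PyObject_IsInstance check can recognize wrapped instances.
    void BindMyFunc(py::module *m) {
      py::class_<MyFunc, std::shared_ptr<MyFunc>> cls(*m, "MyFunc");
      g_my_func_pytype = reinterpret_cast<PyTypeObject *>(cls.ptr());
      cls.def("info", &MyFunc::Info);
    }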