未验证 提交 01247e33 编写于 作者: H HongyuJia 提交者: GitHub

[Opt Performance] Optimize custom operator performance (#52597)

* [Opt Performance] Optimize custom operator performance, reconstruct python API auto-gen, add cache and use const inference

* opt AutoGradMeta implementation

* remove profiler codes

* fix unit test

* change year, 2021->2023

* fix int64_t parse bug
上级 90c3bddf
......@@ -236,7 +236,8 @@ RunCustomOpNode::operator()(paddle::small_vector<std::vector<paddle::Tensor>,
VLOG(7) << "Run Kernel of Grad Custom Op: " << op_type_ << "_grad";
// handle inplace map
ctx.MapPlainOutputs(grad_inputs_name, grad_outputs_names, grad_inplace_map);
ctx.UpdatePlainOutputs(
grad_inputs_name, grad_outputs_names, grad_inplace_map);
(*paddle::OpMetaInfoHelper::GetKernelFn(kernel_map.at(op_type_)[1]))(&ctx);
ctx.AssignInplaceOutputs();
......@@ -443,7 +444,8 @@ RunCustomOpDoubleGradNode::operator()(
VLOG(7) << "Run Kernel of Grad Custom Op: " << name();
// handle inplace map
ctx.MapPlainOutputs(grad_inputs_name, grad_outputs_names, grad_inplace_map);
ctx.UpdatePlainOutputs(
grad_inputs_name, grad_outputs_names, grad_inplace_map);
(*paddle::OpMetaInfoHelper::GetKernelFn(kernel_map.at(op_type_)[2]))(&ctx);
ctx.AssignInplaceOutputs();
......
......@@ -28,6 +28,7 @@ limitations under the License. */
#include "paddle/fluid/eager/api/utils/global_utils.h"
#include "paddle/fluid/framework/attribute.h"
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/custom_operator_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/phi_utils.h"
......@@ -52,87 +53,6 @@ DECLARE_string(tensor_operants_mode);
namespace paddle {
namespace framework {
namespace detail {
// dynamic lib load func
// Resolves symbol `name` from the dynamic library `handle` via dlsym and
// reinterprets it as a function of type T. Raises a NotFound error that
// includes the platform-specific loader error message when the symbol is
// missing. `name` is taken by const reference to avoid copying the string
// on every lookup.
template <typename T>
static T* DynLoad(void* handle, const std::string& name) {
  T* func = reinterpret_cast<T*>(dlsym(handle, name.c_str()));
#if !defined(_WIN32)
  auto errorno = dlerror();
#else
  auto errorno = GetLastError();
#endif  // !_WIN32
  PADDLE_ENFORCE_NOT_NULL(
      func,
      platform::errors::NotFound(
          "Failed to load dynamic operator library, error message(%s).",
          errorno));
  return func;
}
// Returns true when `var_name` contains kTensorVectorSuffix, i.e. the
// variable represents a std::vector<Tensor> argument.
// NOTE(review): rfind(...) != npos is a substring match, not a strict
// suffix match — names with the marker in the middle also qualify.
inline static bool IsDuplicableVar(const std::string& var_name) {
  return var_name.rfind(kTensorVectorSuffix) != std::string::npos;
}
// Returns true when `var_name` contains kOptionalSuffix, i.e. the variable
// was registered as an optional input/output.
// NOTE(review): like IsDuplicableVar, this is a substring match rather than
// a strict suffix check.
inline static bool IsOptionalVar(const std::string& var_name) {
  return var_name.rfind(kOptionalSuffix) != std::string::npos;
}
// Strips gradient suffixes from `var_name` to recover the forward variable
// name. For double-grad variables, a trailing kDoubleGradNewOutSuffix is
// removed first (when present), then the last kGradVarSuffixSize characters
// (the grad suffix) are dropped.
inline static std::string NoGrad(const std::string& var_name,
                                 bool is_double_grad = false) {
  std::string new_out_suffix = kDoubleGradNewOutSuffix;
  std::string tmp_var_name(var_name);
  if (is_double_grad &&
      (tmp_var_name.rfind(new_out_suffix) != std::string::npos)) {
    // Use the suffix's actual length instead of the previous hard-coded 4,
    // so this stays correct if kDoubleGradNewOutSuffix ever changes.
    tmp_var_name =
        tmp_var_name.substr(0, tmp_var_name.size() - new_out_suffix.size());
  }
  return tmp_var_name.substr(0, tmp_var_name.size() - kGradVarSuffixSize);
}
// Decides whether `var_name` denotes a gradient variable. For double-grad
// ops, X@GRAD is a forward var while X@GRAD@GRAD is the grad var, so one
// grad suffix is peeled off via NoGrad before checking.
inline static bool IsGradVar(const std::string& var_name, bool is_double_grad) {
  const std::string checked_name =
      is_double_grad ? NoGrad(var_name) : var_name;
  return checked_name.rfind(kGradVarSuffix) != std::string::npos;
}
// Returns true when `name` is an element of `vec`.
inline static bool IsMemberOf(const std::vector<std::string>& vec,
                              const std::string& name) {
  for (const auto& item : vec) {
    if (item == name) {
      return true;
    }
  }
  return false;
}
// Splits a custom-op attribute declaration of the form "<name>:<type>" into
// its two parts, trimming surrounding spaces from each.
// Returns {name, type_str}; enforces that the ':' separator is present.
static std::vector<std::string> ParseAttrStr(const std::string& attr) {
auto split_pos = attr.find_first_of(":");
PADDLE_ENFORCE_NE(split_pos,
std::string::npos,
platform::errors::InvalidArgument(
"Invalid attribute string format. Attribute string "
"format is `<name>:<type>`."));
std::vector<std::string> rlt;
// 1. name
rlt.emplace_back(string::trim_spaces(attr.substr(0, split_pos)));
// 2. type
rlt.emplace_back(string::trim_spaces(attr.substr(split_pos + 1)));
VLOG(3) << "attr name: " << rlt[0] << ", attr type str: " << rlt[1];
return rlt;
}
} // namespace detail
////////////////// Kernel Define ////////////////////
// custom op kernel call function define
static void RunKernelFunc(
const framework::ExecutionContext& ctx,
......@@ -355,7 +275,7 @@ static void RunKernelFunc(
}
// handle inplace map
kernel_ctx.MapPlainOutputs(inputs, outputs, inplace_map);
kernel_ctx.UpdatePlainOutputs(inputs, outputs, inplace_map);
func(&kernel_ctx);
kernel_ctx.AssignInplaceOutputs();
......
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <string>
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/string/string_helper.h"
#include "paddle/phi/api/ext/op_meta_info.h"
namespace paddle {
namespace framework {
namespace detail {
// dynamic lib load func
// Resolves symbol `name` from the dynamic library `handle` via dlsym and
// reinterprets it as a function of type T. Raises a NotFound error that
// includes the platform-specific loader error message when the symbol is
// missing. `name` is taken by const reference to avoid copying the string
// on every lookup.
template <typename T>
static T* DynLoad(void* handle, const std::string& name) {
  T* func = reinterpret_cast<T*>(dlsym(handle, name.c_str()));
#if !defined(_WIN32)
  auto errorno = dlerror();
#else
  auto errorno = GetLastError();
#endif  // !_WIN32
  PADDLE_ENFORCE_NOT_NULL(
      func,
      platform::errors::NotFound(
          "Failed to load dynamic operator library, error message(%s).",
          errorno));
  return func;
}
// Returns true when `var_name` contains kTensorVectorSuffix, i.e. the
// variable represents a std::vector<Tensor> argument.
// NOTE(review): rfind(...) != npos is a substring match, not a strict
// suffix match — names with the marker in the middle also qualify.
inline static bool IsDuplicableVar(const std::string& var_name) {
  return var_name.rfind(kTensorVectorSuffix) != std::string::npos;
}
// Returns true when `var_name` contains kOptionalSuffix, i.e. the variable
// was registered as an optional input/output.
// NOTE(review): like IsDuplicableVar, this is a substring match rather than
// a strict suffix check.
inline static bool IsOptionalVar(const std::string& var_name) {
  return var_name.rfind(kOptionalSuffix) != std::string::npos;
}
// Strips gradient suffixes from `var_name` to recover the forward variable
// name. For double-grad variables, a trailing kDoubleGradNewOutSuffix is
// removed first (when present), then the last kGradVarSuffixSize characters
// (the grad suffix) are dropped.
inline static std::string NoGrad(const std::string& var_name,
                                 bool is_double_grad = false) {
  std::string new_out_suffix = kDoubleGradNewOutSuffix;
  std::string tmp_var_name(var_name);
  if (is_double_grad &&
      (tmp_var_name.rfind(new_out_suffix) != std::string::npos)) {
    // Use the suffix's actual length instead of the previous hard-coded 4,
    // so this stays correct if kDoubleGradNewOutSuffix ever changes.
    tmp_var_name =
        tmp_var_name.substr(0, tmp_var_name.size() - new_out_suffix.size());
  }
  return tmp_var_name.substr(0, tmp_var_name.size() - kGradVarSuffixSize);
}
// Decides whether `var_name` denotes a gradient variable. For double-grad
// ops, X@GRAD is a forward var while X@GRAD@GRAD is the grad var, so one
// grad suffix is peeled off via NoGrad before checking.
inline static bool IsGradVar(const std::string& var_name, bool is_double_grad) {
  const std::string checked_name =
      is_double_grad ? NoGrad(var_name) : var_name;
  return checked_name.rfind(kGradVarSuffix) != std::string::npos;
}
// Returns true when `name` is an element of `vec`.
inline static bool IsMemberOf(const std::vector<std::string>& vec,
                              const std::string& name) {
  for (const auto& item : vec) {
    if (item == name) {
      return true;
    }
  }
  return false;
}
// Splits a custom-op attribute declaration of the form "<name>:<type>" into
// its two parts, trimming surrounding spaces from each.
// Returns {name, type_str}; enforces that the ':' separator is present.
static std::vector<std::string> ParseAttrStr(const std::string& attr) {
auto split_pos = attr.find_first_of(":");
PADDLE_ENFORCE_NE(split_pos,
std::string::npos,
platform::errors::InvalidArgument(
"Invalid attribute string format. Attribute string "
"format is `<name>:<type>`."));
std::vector<std::string> rlt;
// 1. name
rlt.emplace_back(string::trim_spaces(attr.substr(0, split_pos)));
// 2. type
rlt.emplace_back(string::trim_spaces(attr.substr(split_pos + 1)));
VLOG(3) << "attr name: " << rlt[0] << ", attr type str: " << rlt[1];
return rlt;
}
} // namespace detail
} // namespace framework
} // namespace paddle
......@@ -56,7 +56,6 @@ extern PyTypeObject* g_cudapinnedplace_pytype;
extern PyTypeObject* g_customplace_pytype;
extern PyTypeObject* g_framework_tensor_pytype;
extern PyTypeObject* g_framework_lodtensorarray_pytype;
extern PyTypeObject* g_custom_op_kernel_ctx_pytype;
extern PyTypeObject* g_jit_function_pytype;
int TensorDtype2NumpyDtype(phi::DataType dtype) {
......@@ -432,6 +431,54 @@ std::vector<size_t> CastPyArg2VectorOfSize_t(PyObject* obj, size_t arg_pos) {
return result;
}
// Converts a Python list/tuple of floats (or a single float scalar, or
// None) into a std::vector<float>. Any other object, or a non-float
// element, raises an InvalidArgument error that reports `arg_pos`.
std::vector<float> CastPyArg2VectorOfFloat(PyObject* obj, size_t arg_pos) {
  std::vector<float> result;
  const bool is_list = PyList_Check(obj);
  if (is_list || PyTuple_Check(obj)) {
    const Py_ssize_t len = is_list ? PyList_Size(obj) : PyTuple_Size(obj);
    for (Py_ssize_t i = 0; i < len; i++) {
      // Both accessors hand back borrowed references per the CPython C API,
      // so no reference-count management is required here.
      PyObject* item =
          is_list ? PyList_GetItem(obj, i) : PyTuple_GET_ITEM(obj, i);
      if (PyObject_CheckFloatOrConvertToFloat(&item)) {
        result.emplace_back(static_cast<float>(PyFloat_AsDouble(item)));
      } else {
        PADDLE_THROW(platform::errors::InvalidArgument(
            "argument (position %d) must be "
            "list of float, but got %s at pos %d",
            arg_pos + 1,
            reinterpret_cast<PyTypeObject*>(item->ob_type)->tp_name,
            i));
      }
    }
  } else if (obj == Py_None) {
    // None maps to an empty vector.
    return {};
  } else if (PyObject_CheckFloatOrConvertToFloat(&obj)) {
    // A lone scalar becomes a one-element vector.
    return {static_cast<float>(PyFloat_AsDouble(obj))};
  } else {
    PADDLE_THROW(platform::errors::InvalidArgument(
        "argument (position %d) must be "
        "list of float, but got %s",
        arg_pos + 1,
        reinterpret_cast<PyTypeObject*>(obj->ob_type)->tp_name));
  }
  return result;
}
std::vector<std::vector<size_t>> CastPyArg2VectorOfVectorOfSize_t(
PyObject* obj, size_t arg_pos) {
std::vector<std::vector<size_t>> result;
......@@ -602,19 +649,6 @@ std::vector<std::string> CastPyArg2VectorOfString(PyObject* obj,
}
}
// Casts a Python object to a paddle::CustomOpKernelContext (returned by
// value via the pybind11 caster). The object must be an instance of the
// bound CustomOpKernelContext type (g_custom_op_kernel_ctx_pytype);
// otherwise an InvalidArgument error is raised, so the else branch never
// returns normally.
paddle::CustomOpKernelContext CastPyArg2CustomOpKernelContext(PyObject* obj,
ssize_t arg_pos) {
if (PyObject_IsInstance(
obj, reinterpret_cast<PyObject*>(g_custom_op_kernel_ctx_pytype))) {
return ::pybind11::handle(obj).cast<paddle::CustomOpKernelContext>();
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"argument (position %d) must be CustomOpKernelContext, "
"but got %s",
arg_pos + 1,
reinterpret_cast<PyTypeObject*>(obj->ob_type)->tp_name));
}
}
PyObject* ToPyObject(bool value) {
if (value) {
Py_INCREF(Py_True);
......
......@@ -57,8 +57,6 @@ int64_t CastPyArg2AttrLong(PyObject* obj, ssize_t arg_pos);
size_t CastPyArg2AttrSize_t(PyObject* obj, ssize_t arg_pos);
float CastPyArg2AttrFloat(PyObject* obj, ssize_t arg_pos);
std::string CastPyArg2AttrString(PyObject* obj, ssize_t arg_pos);
paddle::CustomOpKernelContext CastPyArg2CustomOpKernelContext(PyObject* obj,
ssize_t arg_pos);
std::shared_ptr<imperative::VarBase> CastPyArg2VarBase(PyObject* obj,
ssize_t arg_pos);
std::vector<paddle::Tensor> CastPyArg2VectorOfTensor(PyObject* obj,
......@@ -70,6 +68,7 @@ std::vector<phi::DenseTensor> CastPyArg2VectorOfTensorBase(PyObject* obj,
std::vector<int> CastPyArg2VectorOfInt(PyObject* obj, size_t arg_pos);
std::vector<int64_t> CastPyArg2VectorOfInt64(PyObject* obj, size_t arg_pos);
std::vector<size_t> CastPyArg2VectorOfSize_t(PyObject* obj, size_t arg_pos);
std::vector<float> CastPyArg2VectorOfFloat(PyObject* obj, size_t arg_pos);
std::vector<std::vector<size_t>> CastPyArg2VectorOfVectorOfSize_t(
PyObject* obj, size_t arg_pos);
framework::proto::VarType::Type CastPyArg2ProtoType(PyObject* obj,
......
......@@ -464,7 +464,7 @@ std::vector<int64_t> CastPyArg2Longs(PyObject* obj,
for (Py_ssize_t i = 0; i < len; i++) {
item = PyList_GetItem(obj, i);
if (PyObject_CheckLongOrToLong(&item)) {
value.emplace_back(PyLong_AsLong(item));
value.emplace_back((int64_t)PyLong_AsLongLong(item));
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument (position %d) must be "
......@@ -481,7 +481,7 @@ std::vector<int64_t> CastPyArg2Longs(PyObject* obj,
for (Py_ssize_t i = 0; i < len; i++) {
item = PyTuple_GetItem(obj, i);
if (PyObject_CheckLongOrToLong(&item)) {
value.emplace_back(PyLong_AsLong(item));
value.emplace_back((int64_t)PyLong_AsLongLong(item));
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument (position %d) must be "
......@@ -498,7 +498,7 @@ std::vector<int64_t> CastPyArg2Longs(PyObject* obj,
for (Py_ssize_t i = 0; i < len; i++) {
item = PySequence_GetItem(obj, i);
if (PyObject_CheckLongOrToLong(&item)) {
value.emplace_back(PyLong_AsLong(item));
value.emplace_back((int64_t)PyLong_AsLongLong(item));
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument (position %d) must be "
......@@ -512,7 +512,7 @@ std::vector<int64_t> CastPyArg2Longs(PyObject* obj,
} else if (obj == Py_None) {
return {};
} else if (PyObject_CheckLongOrToLong(&obj)) {
return {static_cast<int64_t>(PyLong_AsLong(obj))};
return {(int64_t)PyLong_AsLongLong(obj)};
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument (position %d) must be "
......
......@@ -1013,70 +1013,6 @@ PYBIND11_MODULE(libpaddle, m) {
m.def("_promote_types_if_complex_exists",
&paddle::framework::PromoteTypesIfComplexExists);
py::class_<paddle::CustomOpKernelContext> custom_op_kernel_ctx(
m, "CustomOpKernelContext", R"DOC()DOC");
g_custom_op_kernel_ctx_pytype =
reinterpret_cast<PyTypeObject *>(custom_op_kernel_ctx.ptr());
custom_op_kernel_ctx.def(py::init<>())
.def("add_inputs",
[](paddle::CustomOpKernelContext &self, const py::handle &input) {
PyObject *obj = input.ptr();
if (PyList_Check(obj) || PyTuple_Check(obj)) {
self.EmplaceBackInputs(
std::move(CastPyArg2VectorOfTensor(obj, 1)));
} else if (obj == Py_None) {
// Check optional Tensor, use one un-initialized tensor to
// indicate both Tensor and vector<Tensor> inputs
self.EmplaceBackInput(std::move(paddle::Tensor()));
} else {
self.EmplaceBackInput(std::move(CastPyArg2Tensor(obj, 1)));
}
})
.def("add_outputs",
[](paddle::CustomOpKernelContext &self, py::handle &outputs) {
PyObject *obj = outputs.ptr();
if (PyList_Check(obj) || PyTuple_Check(obj)) {
self.EmplaceBackOutputs(
std::move(CastPyArg2VectorOfTensor(obj, 1)));
} else {
self.EmplaceBackOutput(std::move(CastPyArg2Tensor(obj, 1)));
}
})
.def("add_attr",
[](paddle::CustomOpKernelContext &self, bool attr) {
self.EmplaceBackAttr(attr);
})
.def("add_attr",
[](paddle::CustomOpKernelContext &self, int attr) {
self.EmplaceBackAttr(attr);
})
.def("add_attr",
[](paddle::CustomOpKernelContext &self, float attr) {
self.EmplaceBackAttr(attr);
})
.def("add_attr",
[](paddle::CustomOpKernelContext &self, int64_t attr) {
self.EmplaceBackAttr(attr);
})
.def("add_attr",
[](paddle::CustomOpKernelContext &self, const std::string &attr) {
self.EmplaceBackAttr(attr);
})
.def("add_attr",
[](paddle::CustomOpKernelContext &self,
const std::vector<int> &attr) { self.EmplaceBackAttr(attr); })
.def("add_attr",
[](paddle::CustomOpKernelContext &self,
const std::vector<float> &attr) { self.EmplaceBackAttr(attr); })
.def("add_attr",
[](paddle::CustomOpKernelContext &self,
const std::vector<int64_t> &attr) { self.EmplaceBackAttr(attr); })
.def("add_attr",
[](paddle::CustomOpKernelContext &self,
const std::vector<std::string> &attr) {
self.EmplaceBackAttr(attr);
});
py::class_<Variable>(m, "Variable", R"DOC(Variable Class.
All parameter, weight, gradient are variables in Paddle.
......
......@@ -119,6 +119,7 @@ class PADDLE_API CustomOpKernelContext {
const Tensor& InputAt(size_t idx) const;
std::vector<Tensor> InputsBetween(size_t start, size_t end) const;
Tensor& MutableInputAt(size_t idx);
std::vector<Tensor>* AllMutableInput();
paddle::optional<Tensor> OptionalInputAt(size_t idx);
paddle::optional<std::vector<Tensor>> OptionalInputsBetween(size_t start,
size_t end);
......@@ -144,13 +145,18 @@ class PADDLE_API CustomOpKernelContext {
}
// handle inplace map
void MapPlainOutputs(
void ConstructInplaceIndex(
const std::vector<std::string>& inputs,
const std::vector<std::string>& outputs,
const std::unordered_map<std::string, std::string>& inplace_map);
void UpdatePlainOutputs(
const std::vector<std::string>& inputs,
const std::vector<std::string>& outputs,
const std::unordered_map<std::string, std::string>& inplace_map);
void AssignInplaceOutputs();
std::vector<Tensor*>* AllMutablePlainOutput();
std::unordered_map<size_t, size_t> GetInplaceTensorMap();
std::unordered_map<size_t, size_t> GetInplaceIndexMap();
std::unordered_map<size_t, size_t> GetInplaceReverseIndexMap();
private:
// TODO(chenweihang): replaced be SmallVector
......@@ -159,7 +165,10 @@ class PADDLE_API CustomOpKernelContext {
std::vector<paddle::any> attrs_;
// handle inplace map
std::vector<Tensor*> plain_outputs_;
std::unordered_map<size_t, size_t> inplace_tensor_map_;
// {input: output}
std::unordered_map<size_t, size_t> inplace_idx_map_;
// {output: input}
std::unordered_map<size_t, size_t> inplace_reverse_idx_map_;
std::vector<std::pair<size_t, size_t>> input_range_;
std::vector<std::pair<size_t, size_t>> output_range_;
......
......@@ -103,6 +103,10 @@ Tensor& CustomOpKernelContext::MutableInputAt(size_t idx) {
return inputs_.at(idx);
}
// Exposes the context's full input tensor vector for in-place modification.
// The returned pointer aliases internal storage and is valid only while
// this context is alive.
std::vector<Tensor>* CustomOpKernelContext::AllMutableInput() {
return &inputs_;
}
paddle::optional<Tensor> CustomOpKernelContext::OptionalInputAt(size_t idx) {
if (!inputs_.at(idx).is_initialized()) {
return paddle::none;
......@@ -156,13 +160,15 @@ const std::pair<size_t, size_t>& CustomOpKernelContext::OutputRangeAt(
return output_range_.at(idx);
}
// handle inplace mechanism
// Find out non-inplace output tensors.
// TODO(HongyuJia): Add cache for inplace_tensor_map_ to optimize performance
void CustomOpKernelContext::MapPlainOutputs(
void CustomOpKernelContext::ConstructInplaceIndex(
const std::vector<std::string>& inputs,
const std::vector<std::string>& outputs,
const std::unordered_map<std::string, std::string>& inplace_map) {
// Cache inplace indices.
if (inplace_map.empty() || !inplace_idx_map_.empty()) {
VLOG(4) << "Custom opertor ConstructInplaceIndex no need to recompute.";
return;
}
for (size_t in_idx = 0; in_idx < inputs.size(); ++in_idx) {
auto& input = inputs[in_idx];
if (inplace_map.find(input) == inplace_map.end()) {
......@@ -175,15 +181,26 @@ void CustomOpKernelContext::MapPlainOutputs(
"the input of `Inplace` again and make "
"sure you registered your op accurately. ",
input));
inplace_tensor_map_[in_idx] = distance(outputs.begin(), out_iter);
size_t out_idx = distance(outputs.begin(), out_iter);
inplace_idx_map_[in_idx] = out_idx;
inplace_reverse_idx_map_[out_idx] = in_idx;
}
VLOG(4) << "Custom opertor update inplace input-output map successfully.";
}
// Find out non-inplace output tensors.
void CustomOpKernelContext::UpdatePlainOutputs(
const std::vector<std::string>& inputs,
const std::vector<std::string>& outputs,
const std::unordered_map<std::string, std::string>& inplace_map) {
// Cache plain outputs vector.
if (!plain_outputs_.empty()) {
VLOG(4) << "Custom opertor UpdatePlainOutputs no need to recompute.";
return;
}
ConstructInplaceIndex(inputs, outputs, inplace_map);
for (size_t i = 0; i < outputs.size(); ++i) {
if (std::any_of(
inplace_tensor_map_.begin(),
inplace_tensor_map_.end(),
[i](std::unordered_map<size_t, size_t>::const_reference pair) {
return pair.second == i;
})) {
if (inplace_reverse_idx_map_.find(i) != inplace_reverse_idx_map_.end()) {
continue;
}
size_t output_start_idx = output_range_[i].first;
......@@ -192,11 +209,12 @@ void CustomOpKernelContext::MapPlainOutputs(
plain_outputs_.push_back(&outputs_[idx]);
}
}
VLOG(4) << "Custom opertor update inplace input-output map successfully.";
VLOG(4) << "Custom opertor update plain outputs map successfully.";
}
// Assign input tensor to inplace output tensors.
void CustomOpKernelContext::AssignInplaceOutputs() {
for (auto pair : inplace_tensor_map_) {
for (auto pair : inplace_idx_map_) {
size_t in_start_idx = input_range_[pair.first].first;
size_t in_end_idx = input_range_[pair.first].second;
size_t out_start_idx = output_range_[pair.second].first;
......@@ -213,15 +231,21 @@ void CustomOpKernelContext::AssignInplaceOutputs() {
}
VLOG(4) << "Custom opertor update inplace input-output tensor "
"successfully. Update map size = "
<< inplace_tensor_map_.size();
<< inplace_idx_map_.size();
}
}
// Exposes the cached vector of non-inplace ("plain") output tensor pointers
// for in-place modification. The returned pointer aliases internal storage
// and is valid only while this context is alive.
std::vector<Tensor*>* CustomOpKernelContext::AllMutablePlainOutput() {
return &plain_outputs_;
}
// Returns the cached {input index -> output index} inplace mapping built by
// ConstructInplaceIndex.
// NOTE(review): returns by value, so every call copies the map (signature is
// fixed by the header declaration); consider a const reference if this ever
// shows up on a hot path.
std::unordered_map<size_t, size_t> CustomOpKernelContext::GetInplaceIndexMap() {
return inplace_idx_map_;
}
std::unordered_map<size_t, size_t>
CustomOpKernelContext::GetInplaceTensorMap() {
return inplace_tensor_map_;
CustomOpKernelContext::GetInplaceReverseIndexMap() {
return inplace_reverse_idx_map_;
}
////////////////////// Op Meta Info //////////////////////
......
......@@ -1042,7 +1042,9 @@ def _gen_output_content(
# ' ' * tab space * tab number
indent = ' ' * 4 * 2
inplace_idx = {v: k for k, v in inplace_reverse_idx.items()}
dynamic_content = ""
dynamic_content = f"""
{indent}res = []
{indent}start_idx = 0"""
static_content = f"""
{indent}ins = {{}}
{indent}ins_map = {ins_map}
......@@ -1065,10 +1067,11 @@ def _gen_output_content(
lower_in_names = in_names[in_idx].split("@")[0].lower()
dynamic_content += f"""
{indent}if {lower_in_names} is not None:
{indent} outs['{out_name}'] = [core.eager.Tensor() for _ in range(len({lower_in_names}))]
{indent} res.append(outs[start_idx: start_idx + len({lower_in_names})])
{indent} start_idx += len({lower_in_names})
{indent}else:
{indent} outs['{out_name}'] = core.eager.Tensor()
{indent}ctx.add_outputs(outs['{out_name}'])"""
{indent} res.append(None)
{indent} start_idx += 1"""
static_content += f"""
{indent}if {lower_in_names} is not None:
{indent} outs['{out_name}'] = [helper.create_variable(dtype='float32') for _ in range(len({lower_in_names}))]"""
......@@ -1077,8 +1080,8 @@ def _gen_output_content(
): # inplace vector<Tensor> output case
lower_in_names = in_names[in_idx].split("@")[0].lower()
dynamic_content += f"""
{indent}outs['{out_name}'] = [core.eager.Tensor() for _ in range(len({lower_in_names}))]
{indent}ctx.add_outputs(outs['{out_name}'])"""
{indent}res.append(outs[start_idx: start_idx + len({lower_in_names})])
{indent}start_idx += len({lower_in_names})"""
static_content += f"""
{indent}outs['{out_name}'] = [helper.create_variable(dtype='float32') for _ in range(len({lower_in_names}))]"""
elif (
......@@ -1086,21 +1089,22 @@ def _gen_output_content(
): # inplace optional Tensor output case, handle inplace None input
lower_in_names = in_names[in_idx].split("@")[0].lower()
dynamic_content += f"""
{indent}outs['{out_name}'] = core.eager.Tensor()
{indent}ctx.add_outputs(outs['{out_name}'])"""
{indent}if {lower_in_names} is not None:
{indent} res.append(outs[start_idx])
{indent}else:
{indent} res.append(None)
{indent}start_idx += 1"""
static_content += f"""
{indent}if {lower_in_names} is not None:
{indent} outs['{out_name}'] = helper.create_variable(dtype='float32')"""
else: # general/inplace Tensor output case
dynamic_content += f"""
{indent}outs['{out_name}'] = core.eager.Tensor()
{indent}ctx.add_outputs(outs['{out_name}'])"""
{indent}res.append(outs[start_idx])
{indent}start_idx += 1"""
static_content += f"""
{indent}outs['{out_name}'] = helper.create_variable(dtype='float32')"""
dynamic_content += f"""
{indent}core.eager._run_custom_op(ctx, "{op_name}", True)
{indent}res = [outs[out_name] if isinstance(outs[out_name], list) or outs[out_name]._is_initialized() else None for out_name in outs_list]
{indent}return res[0] if len(res)==1 else res"""
static_content += f"""
......@@ -1134,7 +1138,7 @@ def _custom_api_content(op_name):
API_TEMPLATE = textwrap.dedent(
"""
import paddle.fluid.core as core
from paddle.fluid.core import Tensor, CustomOpKernelContext
from paddle.fluid.core import Tensor
from paddle.fluid.framework import _dygraph_tracer, in_dygraph_mode
from paddle.fluid.layer_helper import LayerHelper
......@@ -1146,11 +1150,7 @@ def _custom_api_content(op_name):
# The output variable's dtype use default value 'float32',
# and the actual dtype of output variable will be inferred in runtime.
if in_dygraph_mode():
ctx = CustomOpKernelContext()
for i in {in_names}:
ctx.add_inputs(i)
for j in {attr_names}:
ctx.add_attr(j)
outs = core.eager._run_custom_op("{op_name}", {params_list})
{dynamic_content}
else:
{static_content}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册