Unverified commit d2122c6f, authored by HongyuJia, committed by GitHub

[CustomOP Inplace] Add customOP inplace check (#51844)

Parent: e35afed7
@@ -15,7 +15,6 @@
#include "paddle/fluid/eager/custom_operator/custom_operator_node.h"
#include "paddle/fluid/framework/custom_operator.h"
-#include "paddle/fluid/framework/op_meta_info_helper.h"
#include "paddle/fluid/platform/profiler/event_tracing.h"
#include "paddle/phi/api/ext/op_meta_info.h"
#include "paddle/phi/core/dense_tensor.h"
@@ -34,17 +33,12 @@ static void ConstructFwdAndBwdMap(
  }
  VLOG(7) << "Construct DoubleGrad's CustomEdgesSlotMap ";
-  auto inputs_names =
-      paddle::framework::OpMetaInfoHelper::GetInputs(vec_map[1]);
-  auto outputs_names =
-      paddle::framework::OpMetaInfoHelper::GetOutputs(vec_map[1]);
-  auto attrs_names = paddle::framework::OpMetaInfoHelper::GetAttrs(vec_map[1]);
-  auto grad_outputs_names =
-      paddle::framework::OpMetaInfoHelper::GetOutputs(vec_map[2]);
-  auto grad_inputs_names =
-      paddle::framework::OpMetaInfoHelper::GetInputs(vec_map[2]);
-  auto grad_attrs_names =
-      paddle::framework::OpMetaInfoHelper::GetAttrs(vec_map[2]);
+  auto inputs_names = paddle::OpMetaInfoHelper::GetInputs(vec_map[1]);
+  auto outputs_names = paddle::OpMetaInfoHelper::GetOutputs(vec_map[1]);
+  auto attrs_names = paddle::OpMetaInfoHelper::GetAttrs(vec_map[1]);
+  auto grad_outputs_names = paddle::OpMetaInfoHelper::GetOutputs(vec_map[2]);
+  auto grad_inputs_names = paddle::OpMetaInfoHelper::GetInputs(vec_map[2]);
+  auto grad_attrs_names = paddle::OpMetaInfoHelper::GetAttrs(vec_map[2]);
  std::vector<std::unordered_map<int, int>> res(5);
  in_out_map[op_type].push_back(res);
  // Prepare pos map for grad_outputs
@@ -170,13 +164,12 @@ RunCustomOpNode::operator()(paddle::small_vector<std::vector<paddle::Tensor>,
    bool create_graph,
    bool is_new_grad) {  // NOLINT
  paddle::CustomOpKernelContext ctx;
-  auto grad_inputs_name = paddle::framework::OpMetaInfoHelper::GetInputs(
-      egr::Controller::Instance().GetOpMetaInfoMap().at(op_type_)[1]);
-  auto grad_outputs_names = paddle::framework::OpMetaInfoHelper::GetOutputs(
-      egr::Controller::Instance().GetOpMetaInfoMap().at(op_type_)[1]);
-  const auto& grad_inplace_map =
-      paddle::framework::OpMetaInfoHelper::GetInplaceMap(
-          egr::Controller::Instance().GetOpMetaInfoMap().at(op_type_)[1]);
+  auto grad_inputs_name = paddle::OpMetaInfoHelper::GetInputs(
+      egr::Controller::Instance().GetOpMetaInfoMap().at(op_type_)[1]);
+  auto grad_outputs_names = paddle::OpMetaInfoHelper::GetOutputs(
+      egr::Controller::Instance().GetOpMetaInfoMap().at(op_type_)[1]);
+  const auto& grad_inplace_map = paddle::OpMetaInfoHelper::GetInplaceMap(
+      egr::Controller::Instance().GetOpMetaInfoMap().at(op_type_)[1]);
  auto map = egr::Controller::Instance().GetCustomEdgesSlotMap().at(op_type_);
  auto kernel_map = egr::Controller::Instance().GetOpMetaInfoMap();
@@ -240,9 +233,9 @@ RunCustomOpNode::operator()(paddle::small_vector<std::vector<paddle::Tensor>,
  }
  VLOG(7) << "Run Kernel of Grad Custom Op: " << op_type_ << "_grad";
+  // handle inplace map
  ctx.MapPlainOutputs(grad_inputs_name, grad_outputs_names, grad_inplace_map);
-  (*paddle::framework::OpMetaInfoHelper::GetKernelFn(
-      kernel_map.at(op_type_)[1]))(&ctx);
+  (*paddle::OpMetaInfoHelper::GetKernelFn(kernel_map.at(op_type_)[1]))(&ctx);
  ctx.AssignInplaceOutputs();
  VLOG(7) << "Get AutogradMeta for inputs and outputs for Custom Op";
@@ -333,8 +326,8 @@ RunCustomOpNode::operator()(paddle::small_vector<std::vector<paddle::Tensor>,
          ctx.InputRangeAt(it->first).second));
    }
-    auto attrs_names = paddle::framework::OpMetaInfoHelper::GetAttrs(
-        meta_info_map.at(op_type_)[2]);
+    auto attrs_names =
+        paddle::OpMetaInfoHelper::GetAttrs(meta_info_map.at(op_type_)[2]);
    std::vector<paddle::any> attrs(attrs_names.size());
    // Prepare attrs for Grad node
    for (auto it = slot_map[1][4].begin(); it != slot_map[1][4].end(); it++) {
@@ -357,12 +350,10 @@ RunCustomOpDoubleGradNode::operator()(
  paddle::CustomOpKernelContext ctx;
  auto meta_info_map = egr::Controller::Instance().GetOpMetaInfoMap();
  const auto& vec_map = meta_info_map.at(op_type_);
-  auto grad_inputs_name =
-      paddle::framework::OpMetaInfoHelper::GetInputs(vec_map[2]);
-  auto grad_outputs_names =
-      paddle::framework::OpMetaInfoHelper::GetOutputs(vec_map[2]);
-  const auto& grad_inplace_map =
-      paddle::framework::OpMetaInfoHelper::GetInplaceMap(vec_map[2]);
+  auto grad_inputs_name = paddle::OpMetaInfoHelper::GetInputs(vec_map[2]);
+  auto grad_outputs_names = paddle::OpMetaInfoHelper::GetOutputs(vec_map[2]);
+  const auto& grad_inplace_map =
+      paddle::OpMetaInfoHelper::GetInplaceMap(vec_map[2]);
  auto map = egr::Controller::Instance().GetCustomEdgesSlotMap().at(op_type_);
  auto kernel_map = egr::Controller::Instance().GetOpMetaInfoMap();
@@ -429,9 +420,9 @@ RunCustomOpDoubleGradNode::operator()(
  }
  VLOG(7) << "Run Kernel of Grad Custom Op: " << name();
+  // handle inplace map
  ctx.MapPlainOutputs(grad_inputs_name, grad_outputs_names, grad_inplace_map);
-  (*paddle::framework::OpMetaInfoHelper::GetKernelFn(
-      kernel_map.at(op_type_)[2]))(&ctx);
+  (*paddle::OpMetaInfoHelper::GetKernelFn(kernel_map.at(op_type_)[2]))(&ctx);
  ctx.AssignInplaceOutputs();
  return outs;
......
@@ -28,7 +28,6 @@ limitations under the License. */
#include "paddle/fluid/eager/api/utils/global_utils.h"
#include "paddle/fluid/framework/attribute.h"
#include "paddle/fluid/framework/convert_utils.h"
-#include "paddle/fluid/framework/op_meta_info_helper.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/phi_utils.h"
@@ -285,7 +284,7 @@ static void RunKernelFunc(
    VLOG(4) << "Initialize phi tensor operants successfully";
  }
-  // handle inplace case
+  // handle inplace map
  kernel_ctx.MapPlainOutputs(inputs, outputs, inplace_map);
  func(&kernel_ctx);
  kernel_ctx.AssignInplaceOutputs();
......
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <string>
#include <vector>
#include "paddle/phi/api/ext/op_meta_info.h"
namespace paddle {
namespace framework {
class OpMetaInfoHelper {
public:
static const std::string& GetOpName(const paddle::OpMetaInfo& info) {
return info.name_;
}
static const std::vector<std::string>& GetInputs(
const paddle::OpMetaInfo& info) {
return info.inputs_;
}
static const std::vector<std::string>& GetOutputs(
const paddle::OpMetaInfo& info) {
return info.outputs_;
}
static const std::vector<std::string>& GetAttrs(
const paddle::OpMetaInfo& info) {
return info.attrs_;
}
static const std::unordered_map<std::string, std::string>& GetInplaceMap(
const paddle::OpMetaInfo& info) {
return info.inplace_map_;
}
static const KernelFunc& GetKernelFn(const paddle::OpMetaInfo& info) {
return info.kernel_fn_;
}
static const InferShapeFunc& GetInferShapeFn(const paddle::OpMetaInfo& info) {
return info.infer_shape_fn_;
}
static const InferDtypeFunc& GetInferDtypeFn(const paddle::OpMetaInfo& info) {
return info.infer_dtype_fn_;
}
};
} // namespace framework
} // namespace paddle
@@ -12,11 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
-#include "paddle/fluid/framework/op_meta_info_helper.h"
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/helper.h"
#include "paddle/fluid/inference/tensorrt/plugin/generic_plugin.h"
#include "paddle/fluid/inference/tensorrt/plugin_arg_mapping_context.h"
+#include "paddle/phi/api/ext/op_meta_info.h"
namespace paddle {
namespace inference {
@@ -48,7 +48,7 @@ class CustomPluginCreater : public OpConverter {
    auto &op_info = meta_info_map.at(op_desc.Type()).front();
    // set inputs
-    auto &op_input_names = framework::OpMetaInfoHelper::GetInputs(op_info);
+    auto &op_input_names = OpMetaInfoHelper::GetInputs(op_info);
    for (auto &param_name : op_input_names) {
      for (auto &arg_name : op_desc.Input(param_name)) {
        inputs.push_back(engine_->GetITensor(arg_name));
@@ -60,7 +60,7 @@ class CustomPluginCreater : public OpConverter {
    // set attrs
    std::vector<nvinfer1::PluginField> plugindatas;
-    auto &op_attrs_names = framework::OpMetaInfoHelper::GetAttrs(op_info);
+    auto &op_attrs_names = OpMetaInfoHelper::GetAttrs(op_info);
    auto &attrs = op_desc.GetAttrMap();
    std::list<int> int_attrs;
@@ -147,7 +147,7 @@ class CustomPluginCreater : public OpConverter {
    CHECK(layer);
    // set outputs
-    auto &op_output_names = framework::OpMetaInfoHelper::GetOutputs(op_info);
+    auto &op_output_names = OpMetaInfoHelper::GetOutputs(op_info);
    std::vector<std::string> output_names;
    for (auto &param_name : op_output_names) {
      for (auto &arg_name : op_desc.Output(param_name))
......
@@ -18,7 +18,6 @@
#include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/framework/data_layout.h"
-#include "paddle/fluid/framework/op_meta_info_helper.h"
#include "paddle/fluid/framework/phi_utils.h"
#include "paddle/fluid/inference/tensorrt/dynamic_shape_infermeta_factory.h"
#include "paddle/phi/core/compat/op_utils.h"
......
@@ -32,7 +32,6 @@ typedef SSIZE_T ssize_t;
#include "paddle/fluid/eager/utils.h"
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/custom_operator.h"
-#include "paddle/fluid/framework/op_meta_info_helper.h"
#include "paddle/fluid/framework/phi_utils.h"
#include "paddle/fluid/framework/python_headers.h"
#include "paddle/fluid/memory/allocation/allocator.h"
@@ -327,18 +326,12 @@ static void ConstructFwdAndBwdMap(
    return;
  } else {
    VLOG(7) << "Construct CustomEdgesSlotMap ";
-    auto inputs_names =
-        paddle::framework::OpMetaInfoHelper::GetInputs(vec_map[0]);
-    auto outputs_names =
-        paddle::framework::OpMetaInfoHelper::GetOutputs(vec_map[0]);
-    auto attrs_names =
-        paddle::framework::OpMetaInfoHelper::GetAttrs(vec_map[0]);
-    auto grad_outputs_names =
-        paddle::framework::OpMetaInfoHelper::GetOutputs(vec_map[1]);
-    auto grad_inputs_names =
-        paddle::framework::OpMetaInfoHelper::GetInputs(vec_map[1]);
-    auto grad_attrs_names =
-        paddle::framework::OpMetaInfoHelper::GetAttrs(vec_map[1]);
+    auto inputs_names = paddle::OpMetaInfoHelper::GetInputs(vec_map[0]);
+    auto outputs_names = paddle::OpMetaInfoHelper::GetOutputs(vec_map[0]);
+    auto attrs_names = paddle::OpMetaInfoHelper::GetAttrs(vec_map[0]);
+    auto grad_outputs_names = paddle::OpMetaInfoHelper::GetOutputs(vec_map[1]);
+    auto grad_inputs_names = paddle::OpMetaInfoHelper::GetInputs(vec_map[1]);
+    auto grad_attrs_names = paddle::OpMetaInfoHelper::GetAttrs(vec_map[1]);
    std::vector<std::unordered_map<int, int>> res(5);
    in_out_map.insert({op_type, {res}});
@@ -525,23 +518,21 @@ static PyObject* eager_api_run_custom_op(PyObject* self,
          "sure you registered your op first and try again. ",
          op_type));
  VLOG(7) << "Run Kernel of Custom Op: " << op_type;
-  std::vector<paddle::any> res_attrs =
-      CastAttrsToTargetType(ctx.Attrs(),
-                            paddle::framework::OpMetaInfoHelper::GetAttrs(
-                                meta_info_map.at(op_type)[0]));
+  std::vector<paddle::any> res_attrs = CastAttrsToTargetType(
+      ctx.Attrs(),
+      paddle::OpMetaInfoHelper::GetAttrs(meta_info_map.at(op_type)[0]));
  ctx.EmplaceBackAttrs(res_attrs);
  const auto& vec_map = meta_info_map.at(op_type);
-  // handle inplace case
-  const auto& inputs = paddle::framework::OpMetaInfoHelper::GetInputs(
-      meta_info_map.at(op_type)[0]);
-  const auto& outputs = paddle::framework::OpMetaInfoHelper::GetOutputs(
-      meta_info_map.at(op_type)[0]);
-  const auto& inplace_map =
-      paddle::framework::OpMetaInfoHelper::GetInplaceMap(
-          meta_info_map.at(op_type)[0]);
+  const auto& inputs =
+      paddle::OpMetaInfoHelper::GetInputs(meta_info_map.at(op_type)[0]);
+  const auto& outputs =
+      paddle::OpMetaInfoHelper::GetOutputs(meta_info_map.at(op_type)[0]);
+  const auto& inplace_map =
+      paddle::OpMetaInfoHelper::GetInplaceMap(meta_info_map.at(op_type)[0]);
+  // handle inplace map
  ctx.MapPlainOutputs(inputs, outputs, inplace_map);
-  (*paddle::framework::OpMetaInfoHelper::GetKernelFn(vec_map[0]))(&ctx);
+  (*paddle::OpMetaInfoHelper::GetKernelFn(vec_map[0]))(&ctx);
  ctx.AssignInplaceOutputs();
  VLOG(7) << "Get AutogradMeta for inputs and outputs for Custom Op";
@@ -569,7 +560,7 @@ static PyObject* eager_api_run_custom_op(PyObject* self,
        trace_backward, &(ins_auto_grad_metas[i]));
  }
-  // handle inplace case
+  // handle inplace map
  for (size_t i = 0; i < ctx.InputRange().size(); i++) {
    if (inplace_map.find(inputs[i]) != inplace_map.end()) {
      size_t input_size =
@@ -653,8 +644,8 @@ static PyObject* eager_api_run_custom_op(PyObject* self,
          ctx.InputRangeAt(it->first).second));
    }
-    auto attrs_names = paddle::framework::OpMetaInfoHelper::GetAttrs(
-        meta_info_map.at(op_type)[1]);
+    auto attrs_names =
+        paddle::OpMetaInfoHelper::GetAttrs(meta_info_map.at(op_type)[1]);
    std::vector<paddle::any> attrs(attrs_names.size());
    // Prepare attrs for Grad node
    for (auto it = slot_map[0][4].begin(); it != slot_map[0][4].end(); it++) {
......
@@ -33,10 +33,8 @@ limitations under the License. */
*/
namespace paddle {
-namespace framework {
-class PADDLE_API OpMetaInfoHelper;
-}  // namespace framework
+class PADDLE_API OpMetaInfoHelper;
using Tensor = paddle::Tensor;
///////////////// Util Marco Define ////////////////
@@ -130,7 +128,7 @@ class PADDLE_API CustomOpKernelContext {
    }
  }
-  // handle inplace case
+  // handle inplace map
  void MapPlainOutputs(
      const std::vector<std::string>& inputs,
      const std::vector<std::string>& outputs,
@@ -144,7 +142,7 @@ class PADDLE_API CustomOpKernelContext {
  std::vector<Tensor> inputs_;
  std::vector<Tensor> outputs_;
  std::vector<paddle::any> attrs_;
-  // handle inplace case
+  // handle inplace map
  std::vector<Tensor*> plain_outputs_;
  std::unordered_map<size_t, size_t> inplace_tensor_map_;
@@ -589,7 +587,7 @@ class PADDLE_API OpMetaInfo {
  OpMetaInfo& SetInferDtypeFn(InferDtypeFunc&& func);
 private:
-  friend class framework::OpMetaInfoHelper;
+  friend class OpMetaInfoHelper;
  // 1. desc info
  std::string name_;
@@ -603,6 +601,39 @@ class PADDLE_API OpMetaInfo {
  InferDtypeFunc infer_dtype_fn_{nullptr};
};
//////////////// Op Meta Info Helper /////////////////
class OpMetaInfoHelper {
public:
static const std::string& GetOpName(const paddle::OpMetaInfo& info) {
return info.name_;
}
static const std::vector<std::string>& GetInputs(
const paddle::OpMetaInfo& info) {
return info.inputs_;
}
static const std::vector<std::string>& GetOutputs(
const paddle::OpMetaInfo& info) {
return info.outputs_;
}
static const std::vector<std::string>& GetAttrs(
const paddle::OpMetaInfo& info) {
return info.attrs_;
}
static const std::unordered_map<std::string, std::string>& GetInplaceMap(
const paddle::OpMetaInfo& info) {
return info.inplace_map_;
}
static const KernelFunc& GetKernelFn(const paddle::OpMetaInfo& info) {
return info.kernel_fn_;
}
static const InferShapeFunc& GetInferShapeFn(const paddle::OpMetaInfo& info) {
return info.infer_shape_fn_;
}
static const InferDtypeFunc& GetInferDtypeFn(const paddle::OpMetaInfo& info) {
return info.infer_dtype_fn_;
}
};
//////////////// Op Meta Info Map /////////////////
class PADDLE_API OpMetaInfoMap {
......
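With OpMetaInfoHelper moved into this public phi header, callers read custom-op metadata as paddle::OpMetaInfoHelper instead of paddle::framework::OpMetaInfoHelper, which is what the call-site changes in this commit do. A minimal caller sketch (the helper function HasInplacePair is illustrative only, not part of this commit):

#include "paddle/phi/api/ext/op_meta_info.h"

// Returns true if a registered custom op declares at least one inplace pair,
// using the accessors of the now-public paddle::OpMetaInfoHelper.
static bool HasInplacePair(const paddle::OpMetaInfo& info) {
  const auto& inplace_map = paddle::OpMetaInfoHelper::GetInplaceMap(info);
  return !inplace_map.empty();
}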
@@ -299,6 +299,27 @@ OpMetaInfoBuilder& OpMetaInfoBuilder::Attrs(std::vector<std::string>&& attrs) {
OpMetaInfoBuilder& OpMetaInfoBuilder::SetInplaceMap(
    std::unordered_map<std::string, std::string>&& inplace_map) {
const std::vector<std::string>& inputs =
OpMetaInfoHelper::GetInputs(*info_ptr_);
const std::vector<std::string>& outputs =
OpMetaInfoHelper::GetOutputs(*info_ptr_);
for (const auto& pair : inplace_map) {
PADDLE_ENFORCE(
std::find(inputs.begin(), inputs.end(), pair.first) != inputs.cend(),
phi::errors::PreconditionNotMet(
"The register of operator %s's `SetInplaceMap` failed. "
"Please make sure: 1. Call `Inputs` and `Outputs` before "
"`SetInplaceMap`; 2. The keys of inplace_map are inside `Inputs`",
name_));
PADDLE_ENFORCE(std::find(outputs.begin(), outputs.end(), pair.second) !=
outputs.cend(),
phi::errors::PreconditionNotMet(
"The register of operator %s's `SetInplaceMap` failed. "
"Please make sure: 1. Call `Inputs` and `Outputs` "
"before `SetInplaceMap`; 2. The values of inplace_map "
"are inside `Outputs`",
name_));
}
  info_ptr_->SetInplaceMap(
      std::forward<std::unordered_map<std::string, std::string>>(inplace_map));
  return *this;
......
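The PADDLE_ENFORCE checks added to OpMetaInfoBuilder::SetInplaceMap above enforce the registration contract for inplace custom operators: Inputs and Outputs must be declared before SetInplaceMap, every key of the inplace map must be a registered input, and every value a registered output. A minimal registration sketch that satisfies the check (the op name custom_relu and kernel ReluForward are illustrative only, not part of this commit):

#include "paddle/extension.h"

// Illustrative inplace kernel: overwrites X in place; the inplace map below
// lets the framework expose the modified X as the "Out" output, so the kernel
// does not allocate or return a separate tensor.
void ReluForward(paddle::Tensor& x) {  // NOLINT
  PD_DISPATCH_FLOATING_TYPES(x.type(), "ReluForward", ([&] {
    auto* data = x.data<data_t>();
    for (int64_t i = 0; i < x.numel(); ++i) {
      data[i] = data[i] > static_cast<data_t>(0) ? data[i]
                                                 : static_cast<data_t>(0);
    }
  }));
}

PD_BUILD_OP(custom_relu)
    .Inputs({"X"})                    // declared before SetInplaceMap
    .Outputs({"Out"})                 // declared before SetInplaceMap
    .SetInplaceMap({{"X", "Out"}})    // key "X" in Inputs, value "Out" in Outputs
    .SetKernelFn(PD_KERNEL(ReluForward));

If SetInplaceMap were called before Inputs/Outputs, or referenced a name that is not registered, the new check rejects the registration with the PreconditionNotMet error shown in the diff; at run time the mapping is then applied by MapPlainOutputs/AssignInplaceOutputs as in the updated call sites above.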