Unverified commit 7ee31a96, authored by Chen Weihang, committed by GitHub

[Perf] Optimize dygraph scheduling performance (#41696)

* split phi and fluid infermeta context

* resolve conflict

* fix type error

* optimize scheduling perf

* spec small vector size

* replace all grad var name

* fix failed tests

* move init default signature

* polish details

* polish details

* fix no init bug

* init sig for tests

* add init sig for infer

* fix infrt error

* fix infrt failed

* fix kunlun error

* fix infrt failed
Parent b5d9c31c
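The core of this change: the infer-meta context stores its inputs and outputs by value in paddle::SmallVector with explicit inline capacities (phi::kInputSmallVectorSize / phi::kOutputSmallVectorSize) instead of as std::shared_ptr<phi::MetaTensor> elements of a default-sized container. A minimal standalone sketch of why that cuts per-op scheduling overhead follows; the Meta type, the inline capacity of 8, and the allocation counts in the comments are illustrative assumptions, not code from this commit.

#include <array>
#include <cstddef>
#include <iostream>
#include <memory>
#include <vector>

// Illustrative stand-in for a meta tensor; the real CompatMetaTensor wraps an
// InferShapeVarPtr plus a couple of flags, so copying it by value is cheap.
struct Meta {
  void* var = nullptr;
  bool is_runtime = true;
};

// Before: every input costs one heap allocation for the shared object and its
// control block, plus the vector's own growth allocations.
std::vector<std::shared_ptr<Meta>> BuildShared(size_t n) {
  std::vector<std::shared_ptr<Meta>> inputs;
  for (size_t i = 0; i < n; ++i) inputs.push_back(std::make_shared<Meta>());
  return inputs;
}

// After (sketch): values live inline in a fixed-capacity buffer, so building
// the context for a typical op (few inputs/outputs) never touches the heap.
template <size_t kInlineSize>
struct SmallMetaVector {  // crude stand-in for paddle::SmallVector
  std::array<Meta, kInlineSize> buf{};
  size_t size = 0;
  void emplace_back(Meta m) { buf[size++] = m; }  // assumes n <= kInlineSize
};

SmallMetaVector<8> BuildInline(size_t n) {
  SmallMetaVector<8> inputs;
  for (size_t i = 0; i < n; ++i) inputs.emplace_back(Meta{});
  return inputs;
}

int main() {
  auto a = BuildShared(4);  // several heap allocations per InferShape call
  auto b = BuildInline(4);  // zero heap allocations per InferShape call
  std::cout << a.size() << " " << b.size << "\n";
  return 0;
}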
@@ -308,10 +308,100 @@ void CompatMetaTensor::share_meta(const MetaTensor& meta_tensor) {
  share_lod(meta_tensor);
}

- phi::InferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
-                                             const std::string& op_type) {
void CompatInferMetaContext::EmplaceBackInput(CompatMetaTensor input) {
  int index = compat_inputs_.size();
compat_inputs_.emplace_back(std::move(input));
input_range_.emplace_back(std::pair<int, int>(index, index + 1));
}
void CompatInferMetaContext::EmplaceBackOutput(CompatMetaTensor output) {
int index = compat_outputs_.size();
compat_outputs_.emplace_back(std::move(output));
output_range_.emplace_back(std::pair<int, int>(index, index + 1));
}
void CompatInferMetaContext::EmplaceBackInputs(
paddle::SmallVector<CompatMetaTensor, phi::kInputSmallVectorSize> inputs) {
int index = compat_inputs_.size();
input_range_.emplace_back(std::pair<int, int>(index, index + inputs.size()));
compat_inputs_.insert(compat_inputs_.end(),
std::make_move_iterator(inputs.begin()),
std::make_move_iterator(inputs.end()));
}
void CompatInferMetaContext::EmplaceBackOutputs(
paddle::SmallVector<CompatMetaTensor, phi::kOutputSmallVectorSize>
outputs) {
int index = compat_outputs_.size();
output_range_.emplace_back(
std::pair<int, int>(index, index + outputs.size()));
compat_outputs_.insert(compat_outputs_.end(),
std::make_move_iterator(outputs.begin()),
std::make_move_iterator(outputs.end()));
}
const phi::MetaTensor& CompatInferMetaContext::InputAt(size_t idx) const {
return compat_inputs_.at(idx);
}
paddle::optional<const phi::MetaTensor&>
CompatInferMetaContext::OptionalInputAt(size_t idx) const {
const auto& input = compat_inputs_.at(idx);
return input.initialized()
? paddle::optional<const phi::MetaTensor&>{input}
: paddle::optional<const phi::MetaTensor&>{paddle::none};
}
std::vector<const phi::MetaTensor*> CompatInferMetaContext::InputsBetween(
size_t start, size_t end) const {
std::vector<const phi::MetaTensor*> result;
result.reserve(end - start);
for (size_t i = start; i < end; ++i) {
auto& in = compat_inputs_.at(i);
result.emplace_back(in.initialized() ? &in : nullptr);
}
return result;
}
paddle::optional<const std::vector<const phi::MetaTensor*>>
CompatInferMetaContext::OptionalInputsBetween(size_t start, size_t end) const {
const auto& first = compat_inputs_.at(start);
if (first.initialized()) {
std::vector<const phi::MetaTensor*> result;
result.reserve(end - start);
for (size_t i = start; i < end; ++i) {
auto& in = compat_inputs_.at(i);
result.emplace_back(in.initialized() ? &in : nullptr);
}
return paddle::optional<const std::vector<const phi::MetaTensor*>>(result);
}
return paddle::optional<const std::vector<const phi::MetaTensor*>>(
paddle::none);
}
phi::MetaTensor* CompatInferMetaContext::MutableOutputAt(size_t idx) {
auto& out = compat_outputs_.at(idx);
return out.initialized() ? &out : nullptr;
}
std::vector<phi::MetaTensor*> CompatInferMetaContext::MutableOutputBetween(
size_t start, size_t end) {
std::vector<phi::MetaTensor*> result;
result.reserve(end - start);
for (size_t i = start; i < end; ++i) {
auto& out = compat_outputs_.at(i);
result.emplace_back(out.initialized() ? &out : nullptr);
}
return result;
}
CompatInferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
const std::string& op_type) {
  // 1. get kernel args
  InitDefaultKernelSignatureMap();
  auto arg_map_fn = phi::OpUtilsMap::Instance().GetArgumentMappingFn(op_type);
  PADDLE_ENFORCE_NOT_NULL(
      arg_map_fn, platform::errors::NotFound(
@@ -321,52 +411,47 @@ phi::InferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
  VLOG(3) << "BuildInferMetaContext: op kernel signature - " << signature;

  // 2. build infermeta context
- phi::InferMetaContext infer_meta_context(
  CompatInferMetaContext infer_meta_context(
      {ctx->IsRuntime(), ctx->IsRunMKLDNNKernel()});

  auto& input_names = std::get<0>(signature.args);
  auto& attr_names = std::get<1>(signature.args);
  auto& output_names = std::get<2>(signature.args);

- auto kernels_map =
-     phi::KernelFactory::Instance().SelectKernelMap(signature.name);
- if (kernels_map.size() == 0) {
-   PADDLE_THROW(
-       platform::errors::Unimplemented("Not find `%s` kernels when construct "
-                                       "InferMetaContext.",
-                                       signature.name));
- }
- auto attr_defs = kernels_map.cbegin()->second.args_def().attribute_defs();
- // TODO(chenweihang): support multiple inputs and outputs later
- phi::InferMetaContext infer_mete_context;
  const auto& args_def =
      phi::KernelFactory::Instance().GetFirstKernelArgsDef(signature.name);
  const auto& attr_defs = args_def.attribute_defs();

  for (auto& in_name : input_names) {
    if (ctx->HasInputs(in_name)) {
-     auto input_var = ctx->GetInputVarPtrs(in_name);
      auto input_var = std::move(ctx->GetInputVarPtrs(in_name));
      if (input_var.size() == 1) {
        infer_meta_context.EmplaceBackInput(
-           std::make_shared<CompatMetaTensor>(input_var[0], ctx->IsRuntime()));
            std::move(CompatMetaTensor(input_var[0], ctx->IsRuntime())));
      } else {
-       paddle::SmallVector<std::shared_ptr<phi::MetaTensor>> inputs;
-       inputs.reserve(input_var.size());
        paddle::SmallVector<CompatMetaTensor, phi::kInputSmallVectorSize>
            inputs;
        for (const auto& in : input_var) {
-         inputs.push_back(
-             std::make_shared<CompatMetaTensor>(in, ctx->IsRuntime()));
          inputs.emplace_back(
              std::move(CompatMetaTensor(in, ctx->IsRuntime())));
        }
        infer_meta_context.EmplaceBackInputs(std::move(inputs));
      }
    } else {
-     infer_meta_context.EmplaceBackInput({nullptr});
      infer_meta_context.EmplaceBackInput(
          std::move(CompatMetaTensor(ctx->IsRuntime())));
    }
  }
  VLOG(6) << "BuildInferMetaContext: Done inputs";

  auto attr_reader = ctx->Attrs();
  for (size_t i = 0; i < attr_names.size(); ++i) {
-   auto attr_name = attr_names[i];
    auto& attr_name = attr_names[i];
    if (attr_defs[i].type_index == std::type_index(typeid(phi::IntArray))) {
      // When attr is a vector_tensor or tensor, transform it to IntArray
      if (ctx->HasInputs(attr_name) || ctx->HasInput(attr_name)) {
-       const auto& infershape_inputs = ctx->GetInputVarPtrs(attr_name);
        auto infershape_inputs = std::move(ctx->GetInputVarPtrs(attr_name));
        if (ctx->IsRuntime()) {
          // If is in runtime, we will get tensor's value for IntArray
          // and push it into attrs
@@ -456,7 +541,7 @@ phi::InferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
                attr_name));
          }
        } else if (ctx->HasInput(attr_name)) {
-         const auto& infershape_input = ctx->GetInputVarPtrs(attr_name);
          auto infershape_input = std::move(ctx->GetInputVarPtrs(attr_name));
          if (infershape_input.size() == 1) {
            if (ctx->IsRuntime()) {
              Variable* var = BOOST_GET_CONST(Variable*, infershape_input[0]);
@@ -581,7 +666,7 @@ phi::InferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
        // convert from data
        if (attr_defs[i].type_index == std::type_index(typeid(int32_t))) {
          if (ctx->IsRuntime()) {
-           const auto& infershape_inputs = ctx->GetInputVarPtrs(attr_name);
            auto infershape_inputs = std::move(ctx->GetInputVarPtrs(attr_name));
            auto var_temp = BOOST_GET_CONST(Variable*, infershape_inputs[i]);
            auto val = experimental::MakePhiScalarFromVar(*var_temp);
            int32_t val_int = val.template to<int32_t>();
@@ -596,36 +681,41 @@ phi::InferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
    }
  }
  VLOG(6) << "BuildInferMetaContext: Done attrs";

  for (auto& out_name : output_names) {
    if (ctx->HasOutputs(out_name, true)) {
-     auto output_var = ctx->GetOutputVarPtrs(out_name);
      auto output_var = std::move(ctx->GetOutputVarPtrs(out_name));
      if (output_var.size() == 1) {
-       infer_meta_context.EmplaceBackOutput(std::make_shared<CompatMetaTensor>(
-           output_var[0], ctx->IsRuntime()));
        infer_meta_context.EmplaceBackOutput(
            std::move(CompatMetaTensor(output_var[0], ctx->IsRuntime())));
      } else {
-       paddle::SmallVector<std::shared_ptr<phi::MetaTensor>> outputs;
-       outputs.reserve(output_var.size());
        paddle::SmallVector<CompatMetaTensor, phi::kOutputSmallVectorSize>
            outputs;
        for (const auto& out : output_var) {
          if (ctx->IsRuntime()) {
            if (BOOST_GET_CONST(Variable*, out)) {
              outputs.emplace_back(
-                 std::make_shared<CompatMetaTensor>(out, ctx->IsRuntime()));
                  std::move(CompatMetaTensor(out, ctx->IsRuntime())));
              continue;
            }
          } else if (BOOST_GET_CONST(VarDesc*, out)) {
            outputs.emplace_back(
-               std::make_shared<CompatMetaTensor>(out, ctx->IsRuntime()));
                std::move(CompatMetaTensor(out, ctx->IsRuntime())));
            continue;
          }
-         outputs.emplace_back(nullptr);
          outputs.emplace_back(std::move(CompatMetaTensor(ctx->IsRuntime())));
        }
        infer_meta_context.EmplaceBackOutputs(std::move(outputs));
      }
    } else {
-     infer_meta_context.EmplaceBackOutput({nullptr});
      infer_meta_context.EmplaceBackOutput(
          std::move(CompatMetaTensor(ctx->IsRuntime())));
    }
  }
  VLOG(6) << "BuildInferMetaContext: Done outputs";

  return infer_meta_context;
}
......
@@ -18,38 +18,24 @@ limitations under the License. */
#include "paddle/fluid/framework/op_info.h"
#include "paddle/fluid/framework/shape_inference.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/core/meta_tensor.h"

- namespace phi {
- class InferMetaContext;
- }  // namespace phi

namespace paddle {
namespace framework {

- phi::InferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
-                                             const std::string& op_type);
-
- #define DECLARE_INFER_SHAPE_FUNCTOR(op_type, functor_name, fn)      \
-   struct functor_name : public paddle::framework::InferShapeBase {  \
-     void operator()(                                                \
-         paddle::framework::InferShapeContext* ctx) const override { \
-       auto infer_meta_context =                                     \
-           paddle::framework::BuildInferMetaContext(ctx, #op_type);  \
-       fn(&infer_meta_context);                                      \
-     }                                                               \
-   }

// TODO(chenweihang): Support TensorArray later
class CompatMetaTensor : public phi::MetaTensor {
 public:
explicit CompatMetaTensor(bool is_runtime)
: is_runtime_(is_runtime), initialized_(false) {}
  CompatMetaTensor(InferShapeVarPtr var, bool is_runtime)
      : var_(std::move(var)), is_runtime_(is_runtime) {}

  CompatMetaTensor() = default;
  CompatMetaTensor(const CompatMetaTensor&) = default;
  CompatMetaTensor(CompatMetaTensor&&) = default;
- CompatMetaTensor& operator=(const CompatMetaTensor&) = delete;
- CompatMetaTensor& operator=(CompatMetaTensor&&) = delete;
  CompatMetaTensor& operator=(const CompatMetaTensor&) = default;
  CompatMetaTensor& operator=(CompatMetaTensor&&) = default;

  int64_t numel() const override;
@@ -71,6 +57,8 @@ class CompatMetaTensor : public phi::MetaTensor {
  void share_meta(const MetaTensor& meta_tensor) override;

  bool initialized() const override { return initialized_; };

 private:
  const LoD& GetRuntimeLoD() const {
    auto* var = BOOST_GET_CONST(Variable*, var_);

@@ -95,7 +83,62 @@ class CompatMetaTensor : public phi::MetaTensor {
  InferShapeVarPtr var_;
  bool is_runtime_;
bool initialized_{true};
};
// Note: In order to avoid using shared_ptr to manage MetaTensor in
// InferMetaContext, inherit and implement InferMetaContext separately
// for compatibility with fluid, shared_ptr will cause significant decrease
// in scheduling performance
class CompatInferMetaContext : public phi::InferMetaContext {
public:
CompatInferMetaContext() = default;
explicit CompatInferMetaContext(phi::MetaConfig config)
: phi::InferMetaContext(config) {}
void EmplaceBackInput(CompatMetaTensor input);
void EmplaceBackOutput(CompatMetaTensor output);
void EmplaceBackInputs(
paddle::SmallVector<CompatMetaTensor, phi::kInputSmallVectorSize> inputs);
void EmplaceBackOutputs(
paddle::SmallVector<CompatMetaTensor, phi::kOutputSmallVectorSize>
outputs);
const phi::MetaTensor& InputAt(size_t idx) const override;
paddle::optional<const phi::MetaTensor&> OptionalInputAt(
size_t idx) const override;
std::vector<const phi::MetaTensor*> InputsBetween(size_t start,
size_t end) const override;
paddle::optional<const std::vector<const phi::MetaTensor*>>
OptionalInputsBetween(size_t start, size_t end) const override;
phi::MetaTensor* MutableOutputAt(size_t idx) override;
std::vector<phi::MetaTensor*> MutableOutputBetween(size_t start,
size_t end) override;
virtual ~CompatInferMetaContext() = default;
private:
paddle::SmallVector<CompatMetaTensor, phi::kInputSmallVectorSize>
compat_inputs_;
paddle::SmallVector<CompatMetaTensor, phi::kOutputSmallVectorSize>
compat_outputs_;
};
CompatInferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
const std::string& op_type);
#define DECLARE_INFER_SHAPE_FUNCTOR(op_type, functor_name, fn) \
struct functor_name : public paddle::framework::InferShapeBase { \
void operator()( \
paddle::framework::InferShapeContext* ctx) const override { \
auto infer_meta_context = \
paddle::framework::BuildInferMetaContext(ctx, #op_type); \
fn(&infer_meta_context); \
} \
}
}  // namespace framework
}  // namespace paddle
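The relocated DECLARE_INFER_SHAPE_FUNCTOR macro above stamps out an InferShapeBase functor that builds a CompatInferMetaContext and forwards it to a phi infer-meta function. A self-contained sketch of the same mechanism follows; Ctx, MetaCtx, MyInferMeta, and the my_relu name are simplified stand-ins for illustration, not Paddle types.

#include <iostream>
#include <string>

// Simplified stand-ins for InferShapeContext / InferMetaContext.
struct Ctx { std::string op; };
struct MetaCtx { std::string op; bool is_runtime; };

struct InferShapeBase {
  virtual void operator()(Ctx* ctx) const = 0;
  virtual ~InferShapeBase() = default;
};

MetaCtx BuildMetaCtx(Ctx* ctx, const char* op_type) {
  (void)ctx;  // the real builder reads inputs/outputs/attrs from the context
  return MetaCtx{op_type, /*is_runtime=*/true};
}

void MyInferMeta(MetaCtx* m) {  // plays the role of a phi infer-meta function
  std::cout << "infer meta for " << m->op << "\n";
}

// What DECLARE_INFER_SHAPE_FUNCTOR(op_type, functor_name, fn) boils down to:
#define DECLARE_SKETCH_FUNCTOR(op_type, functor_name, fn) \
  struct functor_name : public InferShapeBase {           \
    void operator()(Ctx* ctx) const override {            \
      auto infer_meta_context = BuildMetaCtx(ctx, #op_type); \
      fn(&infer_meta_context);                            \
    }                                                     \
  }

DECLARE_SKETCH_FUNCTOR(my_relu, MyReluInferShapeFunctor, MyInferMeta);

int main() {
  Ctx ctx{"my_relu"};
  MyReluInferShapeFunctor()(&ctx);  // runs the generated InferShape functor
  return 0;
}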
@@ -328,20 +328,21 @@ bool InterpretercoreInferShapeContext::IsRunMKLDNNKernel() const {
}

// TODO(paddle-dev): Can this be template?
- std::vector<InferShapeVarPtr> InterpretercoreInferShapeContext::GetInputVarPtrs(
paddle::SmallVector<InferShapeVarPtr, phi::kInputSmallVectorSize>
InterpretercoreInferShapeContext::GetInputVarPtrs(
    const std::string& name) const {
  const std::vector<Variable*>& vars = InputVars(name);
- std::vector<InferShapeVarPtr> res;
  paddle::SmallVector<InferShapeVarPtr, phi::kInputSmallVectorSize> res;
  res.reserve(vars.size());
  res.insert(res.begin(), vars.begin(), vars.end());
  return res;
}

- std::vector<InferShapeVarPtr>
paddle::SmallVector<InferShapeVarPtr, phi::kOutputSmallVectorSize>
InterpretercoreInferShapeContext::GetOutputVarPtrs(
    const std::string& name) const {
  const std::vector<Variable*>& vars = OutputVars(name);
- std::vector<InferShapeVarPtr> res;
  paddle::SmallVector<InferShapeVarPtr, phi::kOutputSmallVectorSize> res;
  res.reserve(vars.size());
  res.insert(res.begin(), vars.begin(), vars.end());
  return res;
......
@@ -90,11 +90,11 @@ class InterpretercoreInferShapeContext : public InferShapeContext {
  bool IsRunMKLDNNKernel() const override;

  // TODO(paddle-dev): Can this be template?
- std::vector<InferShapeVarPtr> GetInputVarPtrs(
-     const std::string& name) const override;
  paddle::SmallVector<InferShapeVarPtr, phi::kInputSmallVectorSize>
  GetInputVarPtrs(const std::string& name) const override;

- std::vector<InferShapeVarPtr> GetOutputVarPtrs(
-     const std::string& name) const override;
  paddle::SmallVector<InferShapeVarPtr, phi::kOutputSmallVectorSize>
  GetOutputVarPtrs(const std::string& name) const override;

  DDim GetInputDim(const std::string& name) const override;
......
@@ -202,10 +202,10 @@ class CompileTimeInferShapeContext : public InferShapeContext {
    }
  }

- std::vector<InferShapeVarPtr> GetInputVarPtrs(
-     const std::string &name) const override {
  paddle::SmallVector<InferShapeVarPtr, phi::kInputSmallVectorSize>
  GetInputVarPtrs(const std::string &name) const override {
    const std::vector<std::string> arg_names = Inputs(name);
-   std::vector<InferShapeVarPtr> res;
    paddle::SmallVector<InferShapeVarPtr, phi::kInputSmallVectorSize> res;
    res.reserve(arg_names.size());
    std::transform(arg_names.begin(), arg_names.end(), std::back_inserter(res),
                   [this](const std::string &name) {

@@ -214,10 +214,10 @@ class CompileTimeInferShapeContext : public InferShapeContext {
    return res;
  }

- std::vector<InferShapeVarPtr> GetOutputVarPtrs(
-     const std::string &name) const override {
  paddle::SmallVector<InferShapeVarPtr, phi::kOutputSmallVectorSize>
  GetOutputVarPtrs(const std::string &name) const override {
    const std::vector<std::string> arg_names = Outputs(name);
-   std::vector<InferShapeVarPtr> res;
    paddle::SmallVector<InferShapeVarPtr, phi::kOutputSmallVectorSize> res;
    res.reserve(arg_names.size());
    std::transform(arg_names.begin(), arg_names.end(), std::back_inserter(res),
                   [this](const std::string &name) {
......
@@ -945,19 +945,19 @@ class RuntimeInferShapeContext : public InferShapeContext {
  }

  // TODO(paddle-dev): Can this be template?
- std::vector<InferShapeVarPtr> GetInputVarPtrs(
-     const std::string& name) const override {
  paddle::SmallVector<InferShapeVarPtr, phi::kInputSmallVectorSize>
  GetInputVarPtrs(const std::string& name) const override {
    const std::vector<Variable*>& vars = InputVars(name);
-   std::vector<InferShapeVarPtr> res;
    paddle::SmallVector<InferShapeVarPtr, phi::kInputSmallVectorSize> res;
    res.reserve(vars.size());
    res.insert(res.begin(), vars.begin(), vars.end());
    return res;
  }

- std::vector<InferShapeVarPtr> GetOutputVarPtrs(
-     const std::string& name) const override {
  paddle::SmallVector<InferShapeVarPtr, phi::kOutputSmallVectorSize>
  GetOutputVarPtrs(const std::string& name) const override {
    const std::vector<Variable*>& vars = OutputVars(name);
-   std::vector<InferShapeVarPtr> res;
    paddle::SmallVector<InferShapeVarPtr, phi::kOutputSmallVectorSize> res;
    res.reserve(vars.size());
    res.insert(res.begin(), vars.begin(), vars.end());
    return res;
@@ -1324,8 +1324,8 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
              << ", using_kernel_key:" << *kernel_type_.get();
      auto try_pt_kernel_key =
          TransOpKernelTypeToPhiKernelKey(*kernel_type_.get());
-     if (!phi::KernelFactory::Instance().IsSelectKernelValid(
-             pt_kernel_name, try_pt_kernel_key)) {
      if (!phi::KernelFactory::Instance().HasKernel(pt_kernel_name,
                                                    try_pt_kernel_key)) {
        kernel_type_->library_type_ = expected_kernel_key_library_type;
        VLOG(3) << "modify XPU KP kernel in static graph: " << type_
                << " is failed " << *kernel_type_.get();
@@ -2113,10 +2113,12 @@ OpKernelType OperatorWithKernel::GetKernelTypeForVar(

KernelSignature OperatorWithKernel::GetExpectedPhiKernelArgs(
    const ExecutionContext& ctx) const {
- InitDefaultKernelSignatureMap();
  ExecutionArgumentMappingContext arg_mapping_ctx(ctx);
- return phi::OpUtilsMap::Instance().GetArgumentMappingFn(Type())(
-     arg_mapping_ctx);
  if (arg_map_fn_ == nullptr) {
    arg_map_fn_.reset(new phi::ArgumentMappingFn(
        phi::OpUtilsMap::Instance().GetArgumentMappingFn(Type())));
  }
  return (*arg_map_fn_)(arg_mapping_ctx);
}

Scope* OperatorWithKernel::PreparePhiData(
......
@@ -701,6 +701,7 @@ class OperatorWithKernel : public OperatorBase {
  mutable bool run_kp_kernel = false;
  mutable std::unique_ptr<phi::KernelSignature> pt_kernel_signature_;
  mutable std::unique_ptr<phi::Kernel> pt_kernel_;
  mutable std::unique_ptr<phi::ArgumentMappingFn> arg_map_fn_;
};

extern bool OpSupportGPU(const std::string& op_type);
......
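The new mutable arg_map_fn_ member lets GetExpectedPhiKernelArgs resolve the argument-mapping function from OpUtilsMap once per operator instance rather than on every dygraph call. A standalone sketch of that lazy-caching pattern follows; the registry, the ArgMapFn alias, and the op names are illustrative stand-ins, not Paddle's API.

#include <functional>
#include <iostream>
#include <memory>
#include <string>
#include <unordered_map>

using ArgMapFn = std::function<std::string(const std::string&)>;

// Stand-in for phi::OpUtilsMap: a global name -> mapping-function registry.
ArgMapFn LookUpRegistry(const std::string& op_type) {
  static const std::unordered_map<std::string, ArgMapFn> registry = {
      {"scale", [](const std::string& in) { return "scale(" + in + ")"; }},
  };
  return registry.at(op_type);
}

class Op {
 public:
  explicit Op(std::string type) : type_(std::move(type)) {}

  // Before: LookUpRegistry(type_) ran on every call (map lookup + functor copy).
  // After: the result is cached in a mutable member, mirroring arg_map_fn_.
  std::string MapArgs(const std::string& in) const {
    if (arg_map_fn_ == nullptr) {
      arg_map_fn_.reset(new ArgMapFn(LookUpRegistry(type_)));
    }
    return (*arg_map_fn_)(in);
  }

 private:
  std::string type_;
  mutable std::unique_ptr<ArgMapFn> arg_map_fn_;
};

int main() {
  Op op("scale");
  std::cout << op.MapArgs("x") << "\n";  // first call fills the cache
  std::cout << op.MapArgs("y") << "\n";  // later calls skip the registry
  return 0;
}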
@@ -25,6 +25,7 @@ limitations under the License. */
#include "paddle/phi/core/compat/convert_utils.h"
#include "paddle/phi/core/compat/op_utils.h"
#include "paddle/phi/core/kernel_factory.h"
#include "paddle/phi/core/type_defs.h"

namespace paddle {
namespace framework {
@@ -40,9 +41,9 @@ class KernelArgsNameMakerByOpProto : public KernelArgsNameMaker {
  ~KernelArgsNameMakerByOpProto() {}

- const paddle::SmallVector<std::string>& GetInputArgsNames() override;
- const paddle::SmallVector<std::string>& GetOutputArgsNames() override;
- const paddle::SmallVector<std::string>& GetAttrsArgsNames() override;
  const paddle::SmallVector<const char*>& GetInputArgsNames() override;
  const paddle::SmallVector<const char*>& GetOutputArgsNames() override;
  const paddle::SmallVector<const char*>& GetAttrsArgsNames() override;

  KernelSignature GetKernelSignature();

@@ -52,9 +53,9 @@ class KernelArgsNameMakerByOpProto : public KernelArgsNameMaker {
 private:
  const framework::proto::OpProto* op_proto_;

- paddle::SmallVector<std::string> input_names_;
- paddle::SmallVector<std::string> output_names_;
- paddle::SmallVector<std::string> attr_names_;
  paddle::SmallVector<const char*> input_names_;
  paddle::SmallVector<const char*> output_names_;
  paddle::SmallVector<const char*> attr_names_;
};

OpKernelType TransPhiKernelKeyToOpKernelType(const phi::KernelKey& kernel_key) {
...@@ -102,7 +103,7 @@ phi::KernelKey FallBackToCpu(const OpKernelType& expected_kernel_key, ...@@ -102,7 +103,7 @@ phi::KernelKey FallBackToCpu(const OpKernelType& expected_kernel_key,
if (platform::is_xpu_place(expected_kernel_key.place_) || if (platform::is_xpu_place(expected_kernel_key.place_) ||
paddle::platform::is_in_xpu_black_list(op.Type())) { paddle::platform::is_in_xpu_black_list(op.Type())) {
VLOG(3) << "phi missing XPU kernel: " << op.Type() VLOG(3) << "phi missing XPU kernel: " << op.Type()
<< ", phipected_kernel_key:" << expected_kernel_key << ", expected_kernel_key:" << expected_kernel_key
<< ", fallbacking to CPU one!"; << ", fallbacking to CPU one!";
return phi::KernelKey(phi::Backend::CPU, kernel_key.layout(), return phi::KernelKey(phi::Backend::CPU, kernel_key.layout(),
kernel_key.dtype()); kernel_key.dtype());
...@@ -111,7 +112,7 @@ phi::KernelKey FallBackToCpu(const OpKernelType& expected_kernel_key, ...@@ -111,7 +112,7 @@ phi::KernelKey FallBackToCpu(const OpKernelType& expected_kernel_key,
#ifdef PADDLE_WITH_ASCEND_CL #ifdef PADDLE_WITH_ASCEND_CL
if (platform::is_npu_place(expected_kernel_key.place_)) { if (platform::is_npu_place(expected_kernel_key.place_)) {
VLOG(3) << "phi missing NPU kernel: " << op.Type() VLOG(3) << "phi missing NPU kernel: " << op.Type()
<< ", phipected_kernel_key:" << expected_kernel_key << ", expected_kernel_key:" << expected_kernel_key
<< ", fallbacking to CPU one!"; << ", fallbacking to CPU one!";
return phi::KernelKey(phi::Backend::CPU, kernel_key.layout(), return phi::KernelKey(phi::Backend::CPU, kernel_key.layout(),
kernel_key.dtype()); kernel_key.dtype());
...@@ -120,7 +121,7 @@ phi::KernelKey FallBackToCpu(const OpKernelType& expected_kernel_key, ...@@ -120,7 +121,7 @@ phi::KernelKey FallBackToCpu(const OpKernelType& expected_kernel_key,
#ifdef PADDLE_WITH_MLU #ifdef PADDLE_WITH_MLU
if (platform::is_mlu_place(expected_kernel_key.place_)) { if (platform::is_mlu_place(expected_kernel_key.place_)) {
VLOG(3) << "phi missing MLU kernel: " << op.Type() VLOG(3) << "phi missing MLU kernel: " << op.Type()
<< ", phipected_kernel_key:" << expected_kernel_key << ", expected_kernel_key:" << expected_kernel_key
<< ", fallbacking to CPU one!"; << ", fallbacking to CPU one!";
return phi::KernelKey(phi::Backend::CPU, kernel_key.layout(), return phi::KernelKey(phi::Backend::CPU, kernel_key.layout(),
kernel_key.dtype()); kernel_key.dtype());
...@@ -129,7 +130,7 @@ phi::KernelKey FallBackToCpu(const OpKernelType& expected_kernel_key, ...@@ -129,7 +130,7 @@ phi::KernelKey FallBackToCpu(const OpKernelType& expected_kernel_key,
#ifdef PADDLE_WITH_IPU #ifdef PADDLE_WITH_IPU
if (platform::is_ipu_place(expected_kernel_key.place_)) { if (platform::is_ipu_place(expected_kernel_key.place_)) {
VLOG(3) << "phi missing IPU kernel: " << op.Type() VLOG(3) << "phi missing IPU kernel: " << op.Type()
<< ", phipected_kernel_key:" << expected_kernel_key << ", expected_kernel_key:" << expected_kernel_key
<< ", fallbacking to CPU one!"; << ", fallbacking to CPU one!";
return phi::KernelKey(phi::Backend::CPU, kernel_key.layout(), return phi::KernelKey(phi::Backend::CPU, kernel_key.layout(),
kernel_key.dtype()); kernel_key.dtype());
...@@ -139,7 +140,7 @@ phi::KernelKey FallBackToCpu(const OpKernelType& expected_kernel_key, ...@@ -139,7 +140,7 @@ phi::KernelKey FallBackToCpu(const OpKernelType& expected_kernel_key,
if (platform::is_custom_place(expected_kernel_key.place_)) { if (platform::is_custom_place(expected_kernel_key.place_)) {
VLOG(3) << "phi missing " << expected_kernel_key.place_.GetDeviceType() VLOG(3) << "phi missing " << expected_kernel_key.place_.GetDeviceType()
<< " kernel: " << op.Type() << " kernel: " << op.Type()
<< ", phipected_kernel_key:" << expected_kernel_key << ", expected_kernel_key:" << expected_kernel_key
<< ", fallbacking to CPU one!"; << ", fallbacking to CPU one!";
return phi::KernelKey(phi::Backend::CPU, kernel_key.layout(), return phi::KernelKey(phi::Backend::CPU, kernel_key.layout(),
kernel_key.dtype()); kernel_key.dtype());
...@@ -148,45 +149,52 @@ phi::KernelKey FallBackToCpu(const OpKernelType& expected_kernel_key, ...@@ -148,45 +149,52 @@ phi::KernelKey FallBackToCpu(const OpKernelType& expected_kernel_key,
return phi::KernelKey(); return phi::KernelKey();
} }
const paddle::SmallVector<std::string>& const paddle::SmallVector<const char*>&
KernelArgsNameMakerByOpProto::GetInputArgsNames() { KernelArgsNameMakerByOpProto::GetInputArgsNames() {
for (int i = 0; i < op_proto_->inputs_size(); ++i) { for (int i = 0; i < op_proto_->inputs_size(); ++i) {
auto& in = op_proto_->inputs()[i]; auto& in = op_proto_->inputs()[i];
auto& in_name = in.name(); auto& in_name = in.name();
if ((in.has_extra() && in.extra()) || (in.has_quant() && in.quant())) { if ((in.has_extra() && in.extra()) || (in.has_quant() && in.quant())) {
VLOG(6) << "Parse PhiKernel input: skip extra & quant input - "
<< in_name;
continue; continue;
} }
// If contains dispensable input, we should override the // If contains dispensable input, we should override the
// OpArgumentMapping method self in phi/ops/compat dir // OpArgumentMapping method self in phi/ops/compat dir
if (in.has_dispensable() && in.dispensable()) { if (in.has_dispensable() && in.dispensable()) {
VLOG(6) << "Parse PhiKernel input: skip dispensable input - " << in_name;
continue; continue;
} }
VLOG(6) << "Parse PhiKernel input: " << in_name; input_names_.emplace_back(in_name.c_str());
input_names_.emplace_back(in_name); }
if (VLOG_IS_ON(10)) {
std::ostringstream sout;
sout << "PhiKernel inputs: ";
std::copy(input_names_.begin(), input_names_.end(),
std::ostream_iterator<const char*>(sout, ", "));
VLOG(10) << sout.str();
} }
return input_names_; return input_names_;
} }
const paddle::SmallVector<std::string>& const paddle::SmallVector<const char*>&
KernelArgsNameMakerByOpProto::GetOutputArgsNames() { KernelArgsNameMakerByOpProto::GetOutputArgsNames() {
for (int i = 0; i < op_proto_->outputs_size(); ++i) { for (int i = 0; i < op_proto_->outputs_size(); ++i) {
auto& out = op_proto_->outputs()[i]; auto& out = op_proto_->outputs()[i];
auto& out_name = out.name(); auto& out_name = out.name();
if ((out.has_extra() && out.extra()) || (out.has_quant() && out.quant())) { if ((out.has_extra() && out.extra()) || (out.has_quant() && out.quant())) {
VLOG(6) << "Parse PhiKernel output: skip extra & quant output - "
<< out_name;
continue; continue;
} }
VLOG(6) << "Parse PhiKernel output: " << out_name; output_names_.emplace_back(out_name.c_str());
output_names_.emplace_back(out_name); }
if (VLOG_IS_ON(10)) {
std::ostringstream sout;
sout << "PhiKernel outputs: ";
std::copy(output_names_.begin(), output_names_.end(),
std::ostream_iterator<const char*>(sout, ", "));
VLOG(10) << sout.str();
} }
return output_names_; return output_names_;
} }
const paddle::SmallVector<std::string>& const paddle::SmallVector<const char*>&
KernelArgsNameMakerByOpProto::GetAttrsArgsNames() { KernelArgsNameMakerByOpProto::GetAttrsArgsNames() {
for (int i = 0; i < op_proto_->attrs_size(); ++i) { for (int i = 0; i < op_proto_->attrs_size(); ++i) {
auto& attr = op_proto_->attrs()[i]; auto& attr = op_proto_->attrs()[i];
...@@ -195,25 +203,26 @@ KernelArgsNameMakerByOpProto::GetAttrsArgsNames() { ...@@ -195,25 +203,26 @@ KernelArgsNameMakerByOpProto::GetAttrsArgsNames() {
attr_name == "op_role" || attr_name == "op_role_var" || attr_name == "op_role" || attr_name == "op_role_var" ||
attr_name == "op_namescope" || attr_name == "op_callstack" || attr_name == "op_namescope" || attr_name == "op_callstack" ||
attr_name == "op_device") { attr_name == "op_device") {
VLOG(6) << "Parse PhiKernel attribute: skip needless attr - "
<< attr_name;
continue; continue;
} }
if ((attr.has_extra() && attr.extra()) || if ((attr.has_extra() && attr.extra()) ||
(attr.has_quant() && attr.quant())) { (attr.has_quant() && attr.quant())) {
VLOG(6) << "Parse PhiKernel attribute: skip extra & quant attr - "
<< attr_name;
continue; continue;
} }
VLOG(6) << "Parse PhiKernel attribute: " << attr_name; attr_names_.emplace_back(attr_name.c_str());
attr_names_.emplace_back(attr_name); }
if (VLOG_IS_ON(10)) {
std::ostringstream sout;
sout << "PhiKernel attributes: ";
std::copy(attr_names_.begin(), attr_names_.end(),
std::ostream_iterator<const char*>(sout, ", "));
VLOG(10) << sout.str();
} }
return attr_names_; return attr_names_;
} }
KernelSignature KernelArgsNameMakerByOpProto::GetKernelSignature() {
- return KernelSignature(phi::TransToPhiKernelName(op_proto_->type()),
  return KernelSignature(phi::TransToPhiKernelName(op_proto_->type()).c_str(),
                         GetInputArgsNames(), GetAttrsArgsNames(),
                         GetOutputArgsNames());
}

@@ -228,7 +237,7 @@ void InitDefaultKernelSignatureMap() {
    if (phi::KernelFactory::Instance().HasCompatiblePhiKernel(op_type) &&
        op_proto) {
      paddle::framework::KernelArgsNameMakerByOpProto maker(op_proto);
-     VLOG(10) << "Register kernel signature for " << op_type;
      VLOG(10) << "Register `" << op_type << "` kernel signature:";
      phi::DefaultKernelSignatureMap::Instance().Insert(
          op_type, std::move(maker.GetKernelSignature()));
    }
......
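In the argument-name parsing above, the per-item VLOG(6) lines are replaced by a single VLOG(10) block guarded by VLOG_IS_ON, so the joined log string is only built when that verbosity is actually enabled. A standalone sketch of the guard-then-join pattern using plain iostreams follows; the verbose flag stands in for VLOG_IS_ON(10), and the names are illustrative.

#include <algorithm>
#include <iostream>
#include <iterator>
#include <sstream>
#include <vector>

int main() {
  std::vector<const char*> input_names = {"X", "Y", "Scale"};
  bool verbose = true;  // stand-in for VLOG_IS_ON(10)

  // The string join (stream construction, copies, formatting) only happens
  // when the verbose path is enabled, keeping the hot parsing loop cheap.
  if (verbose) {
    std::ostringstream sout;
    sout << "PhiKernel inputs: ";
    std::copy(input_names.begin(), input_names.end(),
              std::ostream_iterator<const char*>(sout, ", "));
    std::cout << sout.str() << "\n";
  }
  return 0;
}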
@@ -55,9 +55,9 @@ phi::KernelKey FallBackToCpu(const OpKernelType& expected_kernel_key,
class KernelArgsNameMaker {
 public:
  virtual ~KernelArgsNameMaker() {}

- virtual const paddle::SmallVector<std::string>& GetInputArgsNames() = 0;
- virtual const paddle::SmallVector<std::string>& GetOutputArgsNames() = 0;
- virtual const paddle::SmallVector<std::string>& GetAttrsArgsNames() = 0;
  virtual const paddle::SmallVector<const char*>& GetInputArgsNames() = 0;
  virtual const paddle::SmallVector<const char*>& GetOutputArgsNames() = 0;
  virtual const paddle::SmallVector<const char*>& GetAttrsArgsNames() = 0;
};

void InitDefaultKernelSignatureMap();
......
@@ -21,6 +21,8 @@ limitations under the License. */
#include "paddle/fluid/framework/var_desc.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/phi/core/ddim.h"
#include "paddle/phi/core/type_defs.h"
#include "paddle/utils/small_vector.h"

namespace paddle {
namespace framework {

@@ -106,10 +108,10 @@ class InferShapeContext {
  virtual bool IsRunMKLDNNKernel() const = 0;

- virtual std::vector<InferShapeVarPtr> GetInputVarPtrs(
-     const std::string &name) const = 0;
  virtual paddle::SmallVector<InferShapeVarPtr, phi::kInputSmallVectorSize>
  GetInputVarPtrs(const std::string &name) const = 0;

- virtual std::vector<InferShapeVarPtr> GetOutputVarPtrs(
-     const std::string &name) const = 0;
  virtual paddle::SmallVector<InferShapeVarPtr, phi::kOutputSmallVectorSize>
  GetOutputVarPtrs(const std::string &name) const = 0;

 protected:
  virtual std::vector<DDim> GetRepeatedDims(const std::string &name) const = 0;
......
...@@ -235,9 +235,10 @@ class DygraphInferShapeContext : public framework::InferShapeContext { ...@@ -235,9 +235,10 @@ class DygraphInferShapeContext : public framework::InferShapeContext {
(op_kernel_type_->data_layout_ == framework::DataLayout::kMKLDNN)); (op_kernel_type_->data_layout_ == framework::DataLayout::kMKLDNN));
} }
std::vector<framework::InferShapeVarPtr> GetInputVarPtrs( paddle::SmallVector<framework::InferShapeVarPtr, phi::kInputSmallVectorSize>
const std::string& name) const override { GetInputVarPtrs(const std::string& name) const override {
std::vector<framework::InferShapeVarPtr> res; paddle::SmallVector<framework::InferShapeVarPtr, phi::kInputSmallVectorSize>
res;
auto it = var_map_in_->find(name); auto it = var_map_in_->find(name);
PADDLE_ENFORCE_NE( PADDLE_ENFORCE_NE(
it, var_map_in_->end(), it, var_map_in_->end(),
...@@ -248,9 +249,11 @@ class DygraphInferShapeContext : public framework::InferShapeContext { ...@@ -248,9 +249,11 @@ class DygraphInferShapeContext : public framework::InferShapeContext {
return res; return res;
} }
std::vector<framework::InferShapeVarPtr> GetOutputVarPtrs( paddle::SmallVector<framework::InferShapeVarPtr, phi::kOutputSmallVectorSize>
const std::string& name) const override { GetOutputVarPtrs(const std::string& name) const override {
std::vector<framework::InferShapeVarPtr> res; paddle::SmallVector<framework::InferShapeVarPtr,
phi::kOutputSmallVectorSize>
res;
auto it = var_map_out_->find(name); auto it = var_map_out_->find(name);
PADDLE_ENFORCE_NE( PADDLE_ENFORCE_NE(
it, var_map_out_->end(), it, var_map_out_->end(),
......
...@@ -36,6 +36,8 @@ DECLARE_bool(run_kp_kernel); ...@@ -36,6 +36,8 @@ DECLARE_bool(run_kp_kernel);
namespace paddle { namespace paddle {
namespace imperative { namespace imperative {
static const phi::Kernel empty_kernel;
const std::shared_ptr<VariableWrapper>& GetVariableWrapper( const std::shared_ptr<VariableWrapper>& GetVariableWrapper(
const std::shared_ptr<paddle::imperative::VarBase>& var) { const std::shared_ptr<paddle::imperative::VarBase>& var) {
return var->SharedVar(); return var->SharedVar();
...@@ -108,12 +110,13 @@ PreparedOp::PreparedOp(const framework::OperatorBase& op, ...@@ -108,12 +110,13 @@ PreparedOp::PreparedOp(const framework::OperatorBase& op,
ctx_(ctx), ctx_(ctx),
kernel_type_(kernel_type), kernel_type_(kernel_type),
func_(func), func_(func),
dev_ctx_(dev_ctx) {} dev_ctx_(dev_ctx),
pt_kernel_(empty_kernel) {}
PreparedOp::PreparedOp(const framework::OperatorBase& op, PreparedOp::PreparedOp(const framework::OperatorBase& op,
const framework::RuntimeContext& ctx, const framework::RuntimeContext& ctx,
const framework::OpKernelType& kernel_type, const framework::OpKernelType& kernel_type,
const framework::KernelSignature& kernel_signature, framework::KernelSignature&& kernel_signature,
const phi::Kernel& pt_kernel, const phi::Kernel& pt_kernel,
platform::DeviceContext* dev_ctx) platform::DeviceContext* dev_ctx)
: op_(op), : op_(op),
...@@ -122,7 +125,7 @@ PreparedOp::PreparedOp(const framework::OperatorBase& op, ...@@ -122,7 +125,7 @@ PreparedOp::PreparedOp(const framework::OperatorBase& op,
func_(nullptr), func_(nullptr),
dev_ctx_(dev_ctx), dev_ctx_(dev_ctx),
run_phi_kernel_(true), run_phi_kernel_(true),
pt_kernel_signature_(kernel_signature), pt_kernel_signature_(std::move(kernel_signature)),
pt_kernel_(pt_kernel) {} pt_kernel_(pt_kernel) {}
template <typename VarType> template <typename VarType>
...@@ -170,7 +173,8 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins, ...@@ -170,7 +173,8 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins,
#endif #endif
if (phi::KernelFactory::Instance().HasCompatiblePhiKernel(op.Type())) { if (phi::KernelFactory::Instance().HasCompatiblePhiKernel(op.Type())) {
pt_kernel_signature = op.GetExpectedPhiKernelArgs(dygraph_exe_ctx); pt_kernel_signature =
std::move(op.GetExpectedPhiKernelArgs(dygraph_exe_ctx));
VLOG(6) << pt_kernel_signature; VLOG(6) << pt_kernel_signature;
pt_kernel_name = pt_kernel_signature.name; pt_kernel_name = pt_kernel_signature.name;
...@@ -200,8 +204,8 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins, ...@@ -200,8 +204,8 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins,
<< ", using_kernel_key:" << expected_kernel_key; << ", using_kernel_key:" << expected_kernel_key;
phi::KernelKey try_pt_kernel_key = phi::KernelKey try_pt_kernel_key =
TransOpKernelTypeToPhiKernelKey(expected_kernel_key); TransOpKernelTypeToPhiKernelKey(expected_kernel_key);
if (!phi::KernelFactory::Instance().IsSelectKernelValid( if (!phi::KernelFactory::Instance().HasKernel(pt_kernel_name,
pt_kernel_name, try_pt_kernel_key)) { try_pt_kernel_key)) {
expected_kernel_key.library_type_ = expected_kernel_key_library_type; expected_kernel_key.library_type_ = expected_kernel_key_library_type;
VLOG(3) << "modify XPU KP kernel: " << op.Type() << " is failed " VLOG(3) << "modify XPU KP kernel: " << op.Type() << " is failed "
<< expected_kernel_key; << expected_kernel_key;
...@@ -211,8 +215,8 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins, ...@@ -211,8 +215,8 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins,
#endif #endif
pt_kernel_key = TransOpKernelTypeToPhiKernelKey(expected_kernel_key); pt_kernel_key = TransOpKernelTypeToPhiKernelKey(expected_kernel_key);
auto pt_kernel = phi::KernelFactory::Instance().SelectKernel(pt_kernel_name, auto& pt_kernel = phi::KernelFactory::Instance().SelectKernel(
pt_kernel_key); pt_kernel_name, pt_kernel_key);
if (pt_kernel.IsValid() if (pt_kernel.IsValid()
#if defined(PADDLE_WITH_XPU) && !defined(PADDLE_WITH_XPU_KP) #if defined(PADDLE_WITH_XPU) && !defined(PADDLE_WITH_XPU_KP)
...@@ -227,9 +231,8 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins, ...@@ -227,9 +231,8 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins,
dev_ctx = pool.Get(expected_kernel_key.place_); dev_ctx = pool.Get(expected_kernel_key.place_);
} }
// TODO(chenweihang): using CPUKernel when miss device kernel case return PreparedOp(op, ctx, expected_kernel_key,
return PreparedOp(op, ctx, expected_kernel_key, pt_kernel_signature, std::move(pt_kernel_signature), pt_kernel, dev_ctx);
pt_kernel, dev_ctx);
} else { } else {
VLOG(6) << "Dynamic mode ChoosePhiKernel - kernel `" << pt_kernel_name VLOG(6) << "Dynamic mode ChoosePhiKernel - kernel `" << pt_kernel_name
<< "` not found."; << "` not found.";
...@@ -270,15 +273,16 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins, ...@@ -270,15 +273,16 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins,
if (phi::KernelFactory::Instance().HasCompatiblePhiKernel(op.Type())) { if (phi::KernelFactory::Instance().HasCompatiblePhiKernel(op.Type())) {
auto pt_cpu_kernel_key = auto pt_cpu_kernel_key =
FallBackToCpu(expected_kernel_key, pt_kernel_key, op); FallBackToCpu(expected_kernel_key, pt_kernel_key, op);
auto pt_cpu_kernel = phi::KernelFactory::Instance().SelectKernel( auto& pt_cpu_kernel = phi::KernelFactory::Instance().SelectKernel(
pt_kernel_name, pt_cpu_kernel_key); pt_kernel_name, pt_cpu_kernel_key);
if (pt_cpu_kernel.IsValid()) { if (pt_cpu_kernel.IsValid()) {
VLOG(6) << "Dynamic mode PrepareImpl - kernel name: " << pt_kernel_name VLOG(6) << "Dynamic mode PrepareImpl - kernel name: " << pt_kernel_name
<< " | kernel key: " << pt_cpu_kernel_key << " | kernel key: " << pt_cpu_kernel_key
<< " | kernel: " << pt_cpu_kernel; << " | kernel: " << pt_cpu_kernel;
auto* cpu_ctx = pool.Get(paddle::platform::CPUPlace()); auto* cpu_ctx = pool.Get(paddle::platform::CPUPlace());
return PreparedOp(op, ctx, expected_kernel_key, pt_kernel_signature, return PreparedOp(op, ctx, expected_kernel_key,
pt_cpu_kernel, cpu_ctx); std::move(pt_kernel_signature), pt_cpu_kernel,
cpu_ctx);
} }
} }
} }
...@@ -505,7 +509,6 @@ static void PreparedOpRunPtImpl( ...@@ -505,7 +509,6 @@ static void PreparedOpRunPtImpl(
#endif #endif
} }
// TODO(chenweihang): add debug flags later
if (framework::IsComplexType(kernel_type.data_type_)) { if (framework::IsComplexType(kernel_type.data_type_)) {
HandleComplexGradToRealGrad<VarType>(outs); HandleComplexGradToRealGrad<VarType>(outs);
} }
......
...@@ -154,7 +154,7 @@ class PreparedOp { ...@@ -154,7 +154,7 @@ class PreparedOp {
PreparedOp(const framework::OperatorBase& op, PreparedOp(const framework::OperatorBase& op,
const framework::RuntimeContext& ctx, const framework::RuntimeContext& ctx,
const framework::OpKernelType& kernel_type, const framework::OpKernelType& kernel_type,
const framework::KernelSignature& kernel_signature, framework::KernelSignature&& kernel_signature,
const phi::Kernel& pt_kernel, platform::DeviceContext* dev_ctx); const phi::Kernel& pt_kernel, platform::DeviceContext* dev_ctx);
static PreparedOp Prepare(const NameVarMap<VarBase>& ins, static PreparedOp Prepare(const NameVarMap<VarBase>& ins,
...@@ -206,7 +206,7 @@ class PreparedOp { ...@@ -206,7 +206,7 @@ class PreparedOp {
bool run_phi_kernel_{false}; bool run_phi_kernel_{false};
bool run_kp_kernel_{false}; bool run_kp_kernel_{false};
framework::KernelSignature pt_kernel_signature_; framework::KernelSignature pt_kernel_signature_;
phi::Kernel pt_kernel_; const phi::Kernel& pt_kernel_;
}; };
const inline framework::Attribute& GetAttr( const inline framework::Attribute& GetAttr(
...@@ -289,7 +289,7 @@ void BuildDygraphPhiKernelContext( ...@@ -289,7 +289,7 @@ void BuildDygraphPhiKernelContext(
} }
} }
auto ins_vector = it->second; auto& ins_vector = it->second;
size_t end_idx = start_idx + ins_vector.size(); size_t end_idx = start_idx + ins_vector.size();
for (size_t offset = 0; offset < ins_vector.size(); ++offset) { for (size_t offset = 0; offset < ins_vector.size(); ++offset) {
...@@ -587,7 +587,7 @@ void PreparePhiData(const phi::Kernel& pt_kernel, ...@@ -587,7 +587,7 @@ void PreparePhiData(const phi::Kernel& pt_kernel,
auto& ins_vector = ins.at(input_names[i]); auto& ins_vector = ins.at(input_names[i]);
for (size_t offset = 0; offset < ins_vector.size(); ++offset) { for (size_t offset = 0; offset < ins_vector.size(); ++offset) {
auto var = ins_vector[offset]; auto& var = ins_vector[offset];
const auto* tensor_in = GetTensorFromVar(var->Var()); const auto* tensor_in = GetTensorFromVar(var->Var());
if (tensor_in && tensor_in->IsInitialized()) { if (tensor_in && tensor_in->IsInitialized()) {
if (in_def.backend == phi::Backend::ALL_BACKEND) { if (in_def.backend == phi::Backend::ALL_BACKEND) {
......
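PreparedOp now holds const phi::Kernel& pt_kernel_, bound either to the factory-owned kernel returned by SelectKernel or to the file-local empty_kernel sentinel, instead of copying a phi::Kernel into every prepared op. A standalone sketch of the reference-plus-sentinel idea follows; the Kernel struct, the registry, and the op names are simplified stand-ins, not Paddle's types.

#include <iostream>
#include <string>
#include <unordered_map>

struct Kernel {
  std::string name;  // imagine heavier members (fn pointer, arg defs) in reality
  bool IsValid() const { return !name.empty(); }
};

static const Kernel empty_kernel{};  // sentinel for "no phi kernel selected"

// Stand-in for KernelFactory::SelectKernel: returns a reference to
// registry-owned storage, or to the sentinel when nothing matches.
const Kernel& SelectKernel(const std::string& name) {
  static const std::unordered_map<std::string, Kernel> registry = {
      {"scale", Kernel{"scale"}}};
  auto it = registry.find(name);
  return it == registry.end() ? empty_kernel : it->second;
}

class PreparedOpSketch {
 public:
  explicit PreparedOpSketch(const Kernel& k) : pt_kernel_(k) {}
  bool RunsPhiKernel() const { return pt_kernel_.IsValid(); }

 private:
  const Kernel& pt_kernel_;  // no copy: refers to registry-owned (or empty) kernel
};

int main() {
  PreparedOpSketch with_kernel(SelectKernel("scale"));
  PreparedOpSketch without_kernel(SelectKernel("unknown"));
  std::cout << with_kernel.RunsPhiKernel() << " "
            << without_kernel.RunsPhiKernel() << "\n";  // prints: 1 0
  return 0;
}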
@@ -226,6 +226,7 @@ bool AnalysisPredictor::PrepareScope(
    status_is_cloned_ = true;
  } else {
    paddle::framework::InitDevices();
    paddle::framework::InitDefaultKernelSignatureMap();
    // TODO(wilber): we need to release memory occupied by weights.
    scope_.reset(new paddle::framework::Scope());
    status_is_cloned_ = false;
......
@@ -92,6 +92,7 @@ bool NativePaddlePredictor::Init(
        "The sub_scope should not be nullptr."));
  } else {
    paddle::framework::InitDevices();
    paddle::framework::InitDefaultKernelSignatureMap();
    scope_.reset(new paddle::framework::Scope());
  }
......
...@@ -517,10 +517,8 @@ class WhileGradOpShapeInference : public framework::InferShapeBase { ...@@ -517,10 +517,8 @@ class WhileGradOpShapeInference : public framework::InferShapeBase {
ctx->HasInputs(kOutputs); ctx->HasInputs(kOutputs);
ctx->HasInputs(framework::GradVarName(kOutputs)); ctx->HasInputs(framework::GradVarName(kOutputs));
auto pg_ig_names = ctx->Outputs(kXGRAD); auto pg_ig_names = ctx->Outputs(kXGRAD);
std::vector<framework::InferShapeVarPtr> in_var_ptrs = auto in_var_ptrs = ctx->GetInputVarPtrs(kX);
ctx->GetInputVarPtrs(kX); auto out_var_ptrs = ctx->GetOutputVarPtrs(kXGRAD);
std::vector<framework::InferShapeVarPtr> out_var_ptrs =
ctx->GetOutputVarPtrs(kXGRAD);
PADDLE_ENFORCE_EQ(in_var_ptrs.size(), out_var_ptrs.size(), PADDLE_ENFORCE_EQ(in_var_ptrs.size(), out_var_ptrs.size(),
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"The size of Inputs(X) must be the same as " "The size of Inputs(X) must be the same as "
......
...@@ -63,10 +63,8 @@ class CollectFpnProposalsOp : public framework::OperatorWithKernel { ...@@ -63,10 +63,8 @@ class CollectFpnProposalsOp : public framework::OperatorWithKernel {
context->ShareLoD("MultiLevelRois", "FpnRois"); context->ShareLoD("MultiLevelRois", "FpnRois");
} }
if (context->IsRuntime() && !context->HasInputs("MultiLevelRoIsNum")) { if (context->IsRuntime() && !context->HasInputs("MultiLevelRoIsNum")) {
std::vector<framework::InferShapeVarPtr> roi_inputs = auto roi_inputs = context->GetInputVarPtrs("MultiLevelRois");
context->GetInputVarPtrs("MultiLevelRois"); auto score_inputs = context->GetInputVarPtrs("MultiLevelScores");
std::vector<framework::InferShapeVarPtr> score_inputs =
context->GetInputVarPtrs("MultiLevelScores");
for (size_t i = 0; i < roi_inputs.size(); ++i) { for (size_t i = 0; i < roi_inputs.size(); ++i) {
framework::Variable *roi_var = framework::Variable *roi_var =
BOOST_GET(framework::Variable *, roi_inputs[i]); BOOST_GET(framework::Variable *, roi_inputs[i]);
......
...@@ -60,6 +60,7 @@ limitations under the License. */ ...@@ -60,6 +60,7 @@ limitations under the License. */
#include "paddle/fluid/pybind/uva_utils.h" #include "paddle/fluid/pybind/uva_utils.h"
#include "paddle/phi/core/compat/arg_map_context.h" #include "paddle/phi/core/compat/arg_map_context.h"
#include "paddle/phi/core/compat/type_defs.h" #include "paddle/phi/core/compat/type_defs.h"
#include "paddle/phi/core/type_defs.h"
namespace paddle { namespace paddle {
namespace pybind { namespace pybind {
...@@ -2027,26 +2028,35 @@ void BindImperative(py::module *m_ptr) { ...@@ -2027,26 +2028,35 @@ void BindImperative(py::module *m_ptr) {
*(imperative::AmpOperators::Instance().GetMutableAllowOps()), *(imperative::AmpOperators::Instance().GetMutableAllowOps()),
*(imperative::AmpOperators::Instance().GetMutableBlockOps())); *(imperative::AmpOperators::Instance().GetMutableBlockOps()));
}) })
.def("_get_kernel_signature", .def(
[](imperative::Tracer &self, const std::string &type, "_get_kernel_signature",
const PyNameVarBaseMap &ins, const PyNameVarBaseMap &outs, [](imperative::Tracer &self, const std::string &type,
framework::AttributeMap attrs) { const PyNameVarBaseMap &ins, const PyNameVarBaseMap &outs,
// TODO(xiongkun): move this function outside of tracer. framework::AttributeMap attrs) {
auto ins_map = ConvertToNameTensorMap(ins); // TODO(xiongkun): move this function outside of tracer.
auto outs_map = ConvertToNameTensorMap(outs); auto ins_map = ConvertToNameTensorMap(ins);
{ auto outs_map = ConvertToNameTensorMap(outs);
auto to_vector = [](paddle::SmallVector<std::string> &vec) { {
return std::vector<std::string>(vec.begin(), vec.end()); auto input_to_vector =
}; [](paddle::SmallVector<const char *> &vec) {
auto ret = self.GetExpectedKernelSignature(type, ins_map, return std::vector<std::string>(vec.begin(), vec.end());
outs_map, attrs); };
auto kernelsig_ins = to_vector(std::get<0>(ret.args)); auto output_to_vector =
auto kernelsig_attrs = to_vector(std::get<1>(ret.args)); [](paddle::SmallVector<const char *> &vec) {
auto kernelsig_outs = to_vector(std::get<2>(ret.args)); return std::vector<std::string>(vec.begin(), vec.end());
return std::make_tuple(kernelsig_ins, kernelsig_attrs, };
kernelsig_outs); auto attr_to_vector = [](paddle::SmallVector<const char *> &vec) {
} return std::vector<std::string>(vec.begin(), vec.end());
}) };
auto ret = self.GetExpectedKernelSignature(type, ins_map,
outs_map, attrs);
auto kernelsig_ins = input_to_vector(std::get<0>(ret.args));
auto kernelsig_attrs = attr_to_vector(std::get<1>(ret.args));
auto kernelsig_outs = output_to_vector(std::get<2>(ret.args));
return std::make_tuple(kernelsig_ins, kernelsig_attrs,
kernelsig_outs);
}
})
.def("trace", .def("trace",
[](imperative::Tracer &self, const std::string &type, [](imperative::Tracer &self, const std::string &type,
const PyNameVarBaseMap &ins, const PyNameVarBaseMap &outs, const PyNameVarBaseMap &ins, const PyNameVarBaseMap &outs,
......
@@ -2941,6 +2941,8 @@ All parameter, weight, gradient are variables in Paddle.
        framework::LoadOpMetaInfoAndRegisterOp(dso_name));
  });
  m.def("init_devices", []() { framework::InitDevices(); });
  m.def("init_default_kernel_signatures",
        []() { framework::InitDefaultKernelSignatureMap(); });
  m.def("is_compiled_with_cuda", IsCompiledWithCUDA);
  m.def("is_compiled_with_ascend", IsCompiledWithAscend);
  m.def("is_compiled_with_rocm", IsCompiledWithROCM);
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
#include "paddle/infrt/dialect/phi/pass/kernel_op_desc.h" #include "paddle/infrt/dialect/phi/pass/kernel_op_desc.h"
#include <glog/logging.h> #include <glog/logging.h>
#include "paddle/infrt/dialect/phi/data_type.h" #include "paddle/infrt/dialect/phi/data_type.h"
#include "paddle/phi/core/type_defs.h"
#include "paddle/phi/kernels/declarations.h" #include "paddle/phi/kernels/declarations.h"
namespace infrt { namespace infrt {
...@@ -92,10 +93,10 @@ std::vector<PhiKernelDesc> GetCandidateKernels( ...@@ -92,10 +93,10 @@ std::vector<PhiKernelDesc> GetCandidateKernels(
phi_kernel_desc.input_types.clear(); phi_kernel_desc.input_types.clear();
phi_kernel_desc.output_types.clear(); phi_kernel_desc.output_types.clear();
phi::KernelArgsDef args_def = kernel_key_map.at(kernel_key).args_def(); phi::KernelArgsDef args_def = kernel_key_map.at(kernel_key).args_def();
const paddle::SmallVector<phi::TensorArgDef>& input_arg = const paddle::SmallVector<phi::TensorArgDef, phi::kInputSmallVectorSize>&
args_def.input_defs(); input_arg = args_def.input_defs();
const paddle::SmallVector<phi::TensorArgDef>& output_arg = const paddle::SmallVector<phi::TensorArgDef, phi::kOutputSmallVectorSize>&
args_def.output_defs(); output_arg = args_def.output_defs();
for (auto tensor_arg : input_arg) { for (auto tensor_arg : input_arg) {
phi_kernel_desc.input_types.emplace_back(ConvertPlaceFromPhi(tensor_arg)); phi_kernel_desc.input_types.emplace_back(ConvertPlaceFromPhi(tensor_arg));
} }
......
...@@ -91,6 +91,7 @@ using ValueVariantType = ...@@ -91,6 +91,7 @@ using ValueVariantType =
std::vector<::phi::DenseTensor*>, std::vector<::phi::DenseTensor*>,
paddle::experimental::ScalarBase<::phi::DenseTensor>, paddle::experimental::ScalarBase<::phi::DenseTensor>,
paddle::experimental::IntArrayBase<::phi::DenseTensor>, paddle::experimental::IntArrayBase<::phi::DenseTensor>,
std::vector<const ::phi::MetaTensor*>,
std::vector<::phi::MetaTensor*>, std::vector<::phi::MetaTensor*>,
::phi::MetaConfig, ::phi::MetaConfig,
paddle::experimental::Backend, paddle::experimental::Backend,
......
...@@ -19,45 +19,33 @@ limitations under the License. */ ...@@ -19,45 +19,33 @@ limitations under the License. */
#include <tuple> #include <tuple>
#include "paddle/phi/common/place.h" #include "paddle/phi/common/place.h"
#include "paddle/phi/core/type_defs.h"
#include "paddle/utils/any.h" #include "paddle/utils/any.h"
#include "paddle/utils/flat_hash_map.h" #include "paddle/utils/flat_hash_map.h"
#include "paddle/utils/small_vector.h" #include "paddle/utils/small_vector.h"
namespace phi { namespace phi {
constexpr char kGradVarSuffix[] = "@GRAD";
constexpr size_t kGradVarSuffixSize = 5U;
inline std::string GradVarName(const std::string& var_name) {
std::string result;
result.reserve(var_name.size() + kGradVarSuffixSize);
result += var_name;
result += kGradVarSuffix;
return result;
}
// tuple(input_names, attr_names, output_names) // tuple(input_names, attr_names, output_names)
using KernelArgsTuple = std::tuple<paddle::SmallVector<std::string>, using KernelArgsTuple = std::tuple<paddle::SmallVector<const char*>,
paddle::SmallVector<std::string>, paddle::SmallVector<const char*>,
paddle::SmallVector<std::string>>; paddle::SmallVector<const char*>>;
struct KernelSignature { struct KernelSignature {
std::string name; const char* name;
KernelArgsTuple args; KernelArgsTuple args;
KernelSignature() = default; KernelSignature() = default;
KernelSignature(std::string&& kernel_name, KernelSignature(const char* kernel_name,
paddle::SmallVector<std::string>&& inputs, paddle::SmallVector<const char*>&& inputs,
paddle::SmallVector<std::string>&& attrs, paddle::SmallVector<const char*>&& attrs,
paddle::SmallVector<std::string>&& outputs) paddle::SmallVector<const char*>&& outputs)
: name(std::move(kernel_name)), : name(kernel_name), args(std::make_tuple(inputs, attrs, outputs)) {}
args(std::make_tuple(inputs, attrs, outputs)) {} KernelSignature(const char* kernel_name,
KernelSignature(const std::string& kernel_name, const paddle::SmallVector<const char*>& inputs,
const paddle::SmallVector<std::string>& inputs, const paddle::SmallVector<const char*>& attrs,
const paddle::SmallVector<std::string>& attrs, const paddle::SmallVector<const char*>& outputs)
const paddle::SmallVector<std::string>& outputs)
: name(kernel_name), args(std::make_tuple(inputs, attrs, outputs)) {} : name(kernel_name), args(std::make_tuple(inputs, attrs, outputs)) {}
// TODO(chenweihang): add assign constructor to solve windows compile // TODO(chenweihang): add assign constructor to solve windows compile
......
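The hunk above switches KernelSignature from owning std::string names to borrowing string literals (const char*). A minimal sketch of the effect, based only on the constructors shown in this hunk and not compiled against the repo; MakeAbsGradSignature is an illustrative function name, not something added by this PR:

// Sketch only (headers omitted). The literals have static storage duration,
// so storing const char* pointers is safe and allocation-free.
phi::KernelSignature MakeAbsGradSignature() {
  // Same shape as AbsGradOpArgumentMapping further down in this diff:
  // no std::string temporaries are built for the fixed argument names.
  return phi::KernelSignature("abs_grad", {"X", "Out@GRAD"}, {}, {"X@GRAD"});
}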
...@@ -102,7 +102,7 @@ phi::Place TransToPhiPlace(const Backend& backend, bool set_device_id) { ...@@ -102,7 +102,7 @@ phi::Place TransToPhiPlace(const Backend& backend, bool set_device_id) {
} }
} }
std::string TransToPhiKernelName(const std::string& fluid_op_name) { const std::string& TransToPhiKernelName(const std::string& fluid_op_name) {
return OpUtilsMap::Instance().GetBaseKernelName(fluid_op_name); return OpUtilsMap::Instance().GetBaseKernelName(fluid_op_name);
} }
......
...@@ -22,7 +22,7 @@ limitations under the License. */ ...@@ -22,7 +22,7 @@ limitations under the License. */
namespace phi { namespace phi {
std::string TransToPhiKernelName(const std::string& fluid_op_name); const std::string& TransToPhiKernelName(const std::string& fluid_op_name);
const std::string& TransToFluidOpName(const std::string& phi_kernel_name); const std::string& TransToFluidOpName(const std::string& phi_kernel_name);
Backend TransToPhiBackend(const phi::Place& place); Backend TransToPhiBackend(const phi::Place& place);
......
...@@ -26,6 +26,8 @@ limitations under the License. */ ...@@ -26,6 +26,8 @@ limitations under the License. */
namespace phi { namespace phi {
const static std::string deprecated_kernel_name = "deprecated"; // NOLINT
const std::unordered_set<std::string> standard_kernel_suffixs({ const std::unordered_set<std::string> standard_kernel_suffixs({
"sr", // SelectedRows kernel "sr", // SelectedRows kernel
"raw" // fallback kernel of origfinal fluid op "raw" // fallback kernel of origfinal fluid op
...@@ -134,9 +136,9 @@ class OpUtilsMap { ...@@ -134,9 +136,9 @@ class OpUtilsMap {
arg_mapping_fn_map_.insert({std::move(op_type), std::move(fn)}); arg_mapping_fn_map_.insert({std::move(op_type), std::move(fn)});
} }
std::string GetBaseKernelName(const std::string& op_type) const { const std::string& GetBaseKernelName(const std::string& op_type) const {
if (deprecated_op_names.find(op_type) != deprecated_op_names.end()) { if (deprecated_op_names.find(op_type) != deprecated_op_names.end()) {
return "deprecated"; return deprecated_kernel_name;
} }
auto it = base_kernel_name_map_.find(op_type); auto it = base_kernel_name_map_.find(op_type);
if (it == base_kernel_name_map_.end()) { if (it == base_kernel_name_map_.end()) {
...@@ -150,7 +152,7 @@ class OpUtilsMap { ...@@ -150,7 +152,7 @@ class OpUtilsMap {
auto it = arg_mapping_fn_map_.find(op_type); auto it = arg_mapping_fn_map_.find(op_type);
if (it == arg_mapping_fn_map_.end()) { if (it == arg_mapping_fn_map_.end()) {
auto func = auto func =
[op_type](const ArgumentMappingContext& ctx) -> KernelSignature { [&op_type](const ArgumentMappingContext& ctx) -> KernelSignature {
return DefaultKernelSignatureMap::Instance().Get(op_type); return DefaultKernelSignatureMap::Instance().Get(op_type);
}; };
return func; return func;
......
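Returning a reference from GetBaseKernelName only works because both possible results outlive the call: the map entry and the new file-scope deprecated_kernel_name. A self-contained illustration of the same idiom, with purely hypothetical names (LookupBaseName, kFallbackName):

#include <string>
#include <unordered_map>

// File-scope fallback with static storage duration, mirroring
// deprecated_kernel_name above.
static const std::string kFallbackName = "deprecated";  // NOLINT

const std::string& LookupBaseName(
    const std::unordered_map<std::string, std::string>& table,
    const std::string& op_type) {
  auto it = table.find(op_type);
  // Returning a reference to a local std::string would dangle; both branches
  // below refer to storage that survives the call.
  return it == table.end() ? kFallbackName : it->second;
}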
...@@ -20,14 +20,12 @@ void InferMetaContext::SetMetaConfig(MetaConfig config) { ...@@ -20,14 +20,12 @@ void InferMetaContext::SetMetaConfig(MetaConfig config) {
config_ = std::move(config); config_ = std::move(config);
} }
void InferMetaContext::EmplaceBackInput( void InferMetaContext::EmplaceBackInput(MetaTensor input) {
std::shared_ptr<phi::MetaTensor> input) {
int index = inputs_.size(); int index = inputs_.size();
inputs_.emplace_back(std::move(input)); inputs_.emplace_back(std::move(input));
input_range_.emplace_back(std::pair<int, int>(index, index + 1)); input_range_.emplace_back(std::pair<int, int>(index, index + 1));
} }
void InferMetaContext::EmplaceBackOutput( void InferMetaContext::EmplaceBackOutput(MetaTensor output) {
std::shared_ptr<phi::MetaTensor> output) {
int index = outputs_.size(); int index = outputs_.size();
outputs_.emplace_back(std::move(output)); outputs_.emplace_back(std::move(output));
output_range_.emplace_back(std::pair<int, int>(index, index + 1)); output_range_.emplace_back(std::pair<int, int>(index, index + 1));
...@@ -37,7 +35,7 @@ void InferMetaContext::EmplaceBackAttr(paddle::any attr) { ...@@ -37,7 +35,7 @@ void InferMetaContext::EmplaceBackAttr(paddle::any attr) {
} }
void InferMetaContext::EmplaceBackInputs( void InferMetaContext::EmplaceBackInputs(
paddle::SmallVector<std::shared_ptr<phi::MetaTensor>> inputs) { paddle::SmallVector<MetaTensor, phi::kInputSmallVectorSize> inputs) {
int index = inputs_.size(); int index = inputs_.size();
input_range_.emplace_back(std::pair<int, int>(index, index + inputs.size())); input_range_.emplace_back(std::pair<int, int>(index, index + inputs.size()));
inputs_.insert(inputs_.end(), inputs_.insert(inputs_.end(),
...@@ -45,7 +43,7 @@ void InferMetaContext::EmplaceBackInputs( ...@@ -45,7 +43,7 @@ void InferMetaContext::EmplaceBackInputs(
std::make_move_iterator(inputs.end())); std::make_move_iterator(inputs.end()));
} }
void InferMetaContext::EmplaceBackOutputs( void InferMetaContext::EmplaceBackOutputs(
paddle::SmallVector<std::shared_ptr<phi::MetaTensor>> outputs) { paddle::SmallVector<MetaTensor, phi::kOutputSmallVectorSize> outputs) {
int index = outputs_.size(); int index = outputs_.size();
output_range_.emplace_back( output_range_.emplace_back(
std::pair<int, int>(index, index + outputs.size())); std::pair<int, int>(index, index + outputs.size()));
...@@ -64,24 +62,25 @@ const std::pair<int, int>& InferMetaContext::OutputRangeAt(size_t idx) const { ...@@ -64,24 +62,25 @@ const std::pair<int, int>& InferMetaContext::OutputRangeAt(size_t idx) const {
const MetaConfig& InferMetaContext::GetMetaConfig() const { return config_; } const MetaConfig& InferMetaContext::GetMetaConfig() const { return config_; }
const MetaTensor& InferMetaContext::InputAt(size_t idx) const { const MetaTensor& InferMetaContext::InputAt(size_t idx) const {
return *inputs_.at(idx); return inputs_.at(idx);
} }
paddle::optional<const phi::MetaTensor&> InferMetaContext::OptionalInputAt( paddle::optional<const MetaTensor&> InferMetaContext::OptionalInputAt(
size_t idx) const { size_t idx) const {
const auto& input = inputs_.at(idx); const auto& input = inputs_.at(idx);
return input ? paddle::optional<const phi::MetaTensor&>{static_cast< return input.initialized()
const phi::MetaTensor&>(*input)} ? paddle::optional<const MetaTensor&>{input}
: paddle::optional<const phi::MetaTensor&>{paddle::none}; : paddle::optional<const MetaTensor&>{paddle::none};
} }
std::vector<MetaTensor*> InferMetaContext::InputsBetween(size_t start, std::vector<const MetaTensor*> InferMetaContext::InputsBetween(
size_t end) const { size_t start, size_t end) const {
std::vector<MetaTensor*> result; std::vector<const MetaTensor*> result;
result.reserve(end - start); result.reserve(end - start);
for (size_t i = start; i < end; ++i) { for (size_t i = start; i < end; ++i) {
result.push_back(inputs_.at(i).get()); auto& in = inputs_.at(i);
result.emplace_back(in.initialized() ? &in : nullptr);
} }
return result; return result;
...@@ -91,12 +90,13 @@ paddle::optional<const std::vector<const MetaTensor*>> ...@@ -91,12 +90,13 @@ paddle::optional<const std::vector<const MetaTensor*>>
InferMetaContext::OptionalInputsBetween(size_t start, size_t end) const { InferMetaContext::OptionalInputsBetween(size_t start, size_t end) const {
const auto& first = inputs_.at(start); const auto& first = inputs_.at(start);
if (first) { if (first.initialized()) {
std::vector<const MetaTensor*> result; std::vector<const MetaTensor*> result;
result.reserve(end - start); result.reserve(end - start);
for (size_t i = start; i < end; ++i) { for (size_t i = start; i < end; ++i) {
result.push_back(inputs_.at(i).get()); auto& in = inputs_.at(i);
result.emplace_back(in.initialized() ? &in : nullptr);
} }
return paddle::optional<const std::vector<const MetaTensor*>>(result); return paddle::optional<const std::vector<const MetaTensor*>>(result);
...@@ -105,7 +105,8 @@ InferMetaContext::OptionalInputsBetween(size_t start, size_t end) const { ...@@ -105,7 +105,8 @@ InferMetaContext::OptionalInputsBetween(size_t start, size_t end) const {
} }
MetaTensor* InferMetaContext::MutableOutputAt(size_t idx) { MetaTensor* InferMetaContext::MutableOutputAt(size_t idx) {
return outputs_.at(idx).get(); auto& out = outputs_.at(idx);
return out.initialized() ? &out : nullptr;
} }
std::vector<MetaTensor*> InferMetaContext::MutableOutputBetween(size_t start, std::vector<MetaTensor*> InferMetaContext::MutableOutputBetween(size_t start,
...@@ -113,7 +114,8 @@ std::vector<MetaTensor*> InferMetaContext::MutableOutputBetween(size_t start, ...@@ -113,7 +114,8 @@ std::vector<MetaTensor*> InferMetaContext::MutableOutputBetween(size_t start,
std::vector<MetaTensor*> result; std::vector<MetaTensor*> result;
result.reserve(end - start); result.reserve(end - start);
for (size_t i = start; i < end; ++i) { for (size_t i = start; i < end; ++i) {
result.emplace_back(outputs_.at(i).get()); auto& out = outputs_.at(i);
result.emplace_back(out.initialized() ? &out : nullptr);
} }
return result; return result;
} }
......
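After this change the context owns plain MetaTensor values instead of shared_ptr, and optional slots are signalled by uninitialized tensors, which InputsBetween and MutableOutputAt surface as nullptr. A hedged sketch of how a consumer is expected to treat those entries; ExampleMultiInferMeta is a hypothetical InferMeta function, not one added by this PR:

// Sketch only, using the accessors shown in this hunk.
void ExampleMultiInferMeta(const std::vector<const phi::MetaTensor*>& xs,
                           std::vector<phi::MetaTensor*> outs) {
  for (size_t i = 0; i < xs.size(); ++i) {
    // Uninitialized (optional) inputs now arrive as nullptr rather than as
    // empty shared_ptr objects, so a plain pointer check suffices.
    if (xs[i] == nullptr || outs[i] == nullptr) continue;
    outs[i]->share_meta(*xs[i]);
  }
}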
...@@ -37,28 +37,28 @@ class InferMetaContext { ...@@ -37,28 +37,28 @@ class InferMetaContext {
explicit InferMetaContext(MetaConfig config) : config_(config) {} explicit InferMetaContext(MetaConfig config) : config_(config) {}
void SetMetaConfig(MetaConfig config); void SetMetaConfig(MetaConfig config);
void EmplaceBackInput(std::shared_ptr<phi::MetaTensor> input); const MetaConfig& GetMetaConfig() const;
void EmplaceBackOutput(std::shared_ptr<phi::MetaTensor> output);
void EmplaceBackInput(MetaTensor input);
void EmplaceBackOutput(MetaTensor output);
void EmplaceBackAttr(paddle::any attr); void EmplaceBackAttr(paddle::any attr);
void EmplaceBackInputs( void EmplaceBackInputs(
paddle::SmallVector<std::shared_ptr<phi::MetaTensor>> inputs); paddle::SmallVector<MetaTensor, phi::kInputSmallVectorSize> inputs);
void EmplaceBackOutputs( void EmplaceBackOutputs(
paddle::SmallVector<std::shared_ptr<phi::MetaTensor>> outputs); paddle::SmallVector<MetaTensor, phi::kOutputSmallVectorSize> outputs);
const std::pair<int, int>& InputRangeAt(size_t idx) const; virtual const MetaTensor& InputAt(size_t idx) const;
const std::pair<int, int>& OutputRangeAt(size_t idx) const; virtual paddle::optional<const MetaTensor&> OptionalInputAt(size_t idx) const;
const MetaConfig& GetMetaConfig() const; virtual std::vector<const MetaTensor*> InputsBetween(size_t start,
size_t end) const;
const MetaTensor& InputAt(size_t idx) const; virtual paddle::optional<const std::vector<const MetaTensor*>>
paddle::optional<const phi::MetaTensor&> OptionalInputAt(size_t idx) const;
std::vector<MetaTensor*> InputsBetween(size_t start, size_t end) const;
paddle::optional<const std::vector<const phi::MetaTensor*>>
OptionalInputsBetween(size_t start, size_t end) const; OptionalInputsBetween(size_t start, size_t end) const;
MetaTensor* MutableOutputAt(size_t idx); virtual MetaTensor* MutableOutputAt(size_t idx);
std::vector<MetaTensor*> MutableOutputBetween(size_t start, size_t end); virtual std::vector<MetaTensor*> MutableOutputBetween(size_t start,
size_t end);
template <typename AttrType> template <typename AttrType>
AttrType AttrAt(size_t idx) { AttrType AttrAt(size_t idx) {
...@@ -73,19 +73,24 @@ class InferMetaContext { ...@@ -73,19 +73,24 @@ class InferMetaContext {
} }
} }
private: const std::pair<int, int>& InputRangeAt(size_t idx) const;
const std::pair<int, int>& OutputRangeAt(size_t idx) const;
virtual ~InferMetaContext() = default;
protected:
MetaConfig config_; MetaConfig config_;
// NOTE(chenweihang): Because the MetaTensor is a base class, and MetaTensor paddle::SmallVector<paddle::any, kAttrSmallVectorSize> attrs_;
// objects are all created in each round, so we have to use smart pointer
// here, maybe we can implemented a new InferMetaContext and a series utils
// specifically for fluid to avoid using shared_ptr
paddle::SmallVector<std::shared_ptr<phi::MetaTensor>> inputs_;
paddle::SmallVector<std::shared_ptr<phi::MetaTensor>> outputs_;
paddle::SmallVector<paddle::any> attrs_;
paddle::SmallVector<std::pair<int, int>> input_range_; paddle::SmallVector<std::pair<int, int>, phi::kInputSmallVectorSize>
paddle::SmallVector<std::pair<int, int>> output_range_; input_range_;
paddle::SmallVector<std::pair<int, int>, phi::kOutputSmallVectorSize>
output_range_;
private:
paddle::SmallVector<MetaTensor, phi::kInputSmallVectorSize> inputs_;
paddle::SmallVector<MetaTensor, phi::kOutputSmallVectorSize> outputs_;
}; };
#define PD_INFER_META(...) \ #define PD_INFER_META(...) \
...@@ -159,7 +164,7 @@ struct InferMetaFnImpl<Return (*)(Args...), infer_meta_fn> { ...@@ -159,7 +164,7 @@ struct InferMetaFnImpl<Return (*)(Args...), infer_meta_fn> {
}; };
template <typename... Tail> template <typename... Tail>
struct InferMetaFnCallHelper<const std::vector<MetaTensor*>&, Tail...> { struct InferMetaFnCallHelper<const std::vector<const MetaTensor*>&, Tail...> {
template <int in_idx, int attr_idx, int out_idx, typename... PreviousArgs> template <int in_idx, int attr_idx, int out_idx, typename... PreviousArgs>
static void Call(InferMetaContext* ctx, PreviousArgs&... pargs) { static void Call(InferMetaContext* ctx, PreviousArgs&... pargs) {
static_assert(attr_idx == 0, static_assert(attr_idx == 0,
...@@ -167,7 +172,7 @@ struct InferMetaFnImpl<Return (*)(Args...), infer_meta_fn> { ...@@ -167,7 +172,7 @@ struct InferMetaFnImpl<Return (*)(Args...), infer_meta_fn> {
static_assert(out_idx == 0, static_assert(out_idx == 0,
"InferMeta's Input should appear before Outputs."); "InferMeta's Input should appear before Outputs.");
const std::pair<int, int> range = ctx->InputRangeAt(in_idx); const std::pair<int, int> range = ctx->InputRangeAt(in_idx);
std::vector<MetaTensor*> arg = std::vector<const MetaTensor*> arg =
ctx->InputsBetween(range.first, range.second); ctx->InputsBetween(range.first, range.second);
InferMetaFnCallHelper< InferMetaFnCallHelper<
Tail...>::template Call<in_idx + 1, attr_idx, out_idx>(ctx, Tail...>::template Call<in_idx + 1, attr_idx, out_idx>(ctx,
......
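The specialization above is one arm of the usual compile-time dispatcher that walks an InferMeta function's parameter list and pulls each argument out of the context, advancing in_idx per consumed input. A self-contained miniature of that technique (generic code, not Paddle's actual helpers; Ctx, Dispatcher and Example are illustrative names):

#include <iostream>
#include <vector>

struct Ctx {
  std::vector<int> ints;
  std::vector<double> doubles;
};

template <typename T>
struct Tag {};  // sentinel that terminates the recursion

template <typename Fn, Fn fn>
struct Dispatcher;

template <typename... Args, void (*fn)(Args...)>
struct Dispatcher<void (*)(Args...), fn> {
  template <typename... Rest>
  struct Helper;

  // Peel an int parameter: fetch it from ctx->ints and advance int_idx,
  // analogous to the MetaTensor arms in the hunk above.
  template <typename... Tail>
  struct Helper<int, Tail...> {
    template <int int_idx, int dbl_idx, typename... Prev>
    static void Call(Ctx* ctx, Prev&... prev) {
      int arg = ctx->ints[int_idx];
      Helper<Tail...>::template Call<int_idx + 1, dbl_idx>(ctx, prev..., arg);
    }
  };

  // Peel a double parameter and advance its own index.
  template <typename... Tail>
  struct Helper<double, Tail...> {
    template <int int_idx, int dbl_idx, typename... Prev>
    static void Call(Ctx* ctx, Prev&... prev) {
      double arg = ctx->doubles[dbl_idx];
      Helper<Tail...>::template Call<int_idx, dbl_idx + 1>(ctx, prev..., arg);
    }
  };

  // All parameters collected: invoke the wrapped function.
  template <typename T>
  struct Helper<Tag<T>> {
    template <int int_idx, int dbl_idx, typename... Prev>
    static void Call(Ctx* /*ctx*/, Prev&... prev) {
      fn(prev...);
    }
  };

  static void Call(Ctx* ctx) {
    Helper<Args..., Tag<void>>::template Call<0, 0>(ctx);
  }
};

void Example(int a, double b, int c) {
  std::cout << a << " " << b << " " << c << "\n";
}

int main() {
  Ctx ctx{{1, 3}, {2.5}};
  Dispatcher<decltype(&Example), &Example>::Call(&ctx);  // prints "1 2.5 3"
  return 0;
}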
...@@ -79,7 +79,7 @@ void KernelContext::EmplaceBackAttr(paddle::any attr) { ...@@ -79,7 +79,7 @@ void KernelContext::EmplaceBackAttr(paddle::any attr) {
void KernelContext::AssignInputRange(std::pair<int, int>&& range, size_t idx) { void KernelContext::AssignInputRange(std::pair<int, int>&& range, size_t idx) {
if (idx < input_range_.size()) { if (idx < input_range_.size()) {
input_range_[idx] = range; input_range_[idx] = std::move(range);
} else if (idx == input_range_.size()) { } else if (idx == input_range_.size()) {
input_range_.emplace_back(range); input_range_.emplace_back(range);
} else { } else {
...@@ -93,7 +93,7 @@ void KernelContext::AssignInputRange(std::pair<int, int>&& range, size_t idx) { ...@@ -93,7 +93,7 @@ void KernelContext::AssignInputRange(std::pair<int, int>&& range, size_t idx) {
void KernelContext::AssignOutputRange(std::pair<int, int>&& range, size_t idx) { void KernelContext::AssignOutputRange(std::pair<int, int>&& range, size_t idx) {
if (idx < output_range_.size()) { if (idx < output_range_.size()) {
output_range_[idx] = range; output_range_[idx] = std::move(range);
} else if (idx == output_range_.size()) { } else if (idx == output_range_.size()) {
output_range_.emplace_back(range); output_range_.emplace_back(range);
} else { } else {
......
...@@ -19,6 +19,8 @@ ...@@ -19,6 +19,8 @@
namespace phi { namespace phi {
const static Kernel empty_kernel; // NOLINT
uint32_t KernelKey::Hash::operator()(const KernelKey& key) const { uint32_t KernelKey::Hash::operator()(const KernelKey& key) const {
uint32_t hash_value = 0; uint32_t hash_value = 0;
// |----31-20------|---19-12---|---11-8----|---7-0---| // |----31-20------|---19-12---|---11-8----|---7-0---|
...@@ -37,15 +39,15 @@ KernelFactory& KernelFactory::Instance() { ...@@ -37,15 +39,15 @@ KernelFactory& KernelFactory::Instance() {
return g_op_kernel_factory; return g_op_kernel_factory;
} }
Kernel KernelFactory::SelectKernel(const std::string& kernel_name, const Kernel& KernelFactory::SelectKernel(const std::string& kernel_name,
const KernelKey& kernel_key) const { const KernelKey& kernel_key) const {
auto iter = kernels_.find(kernel_name); auto iter = kernels_.find(kernel_name);
if (iter == kernels_.end()) { if (iter == kernels_.end()) {
return Kernel(); return empty_kernel;
} }
auto kernel_iter = iter->second.find(kernel_key); auto kernel_iter = iter->second.find(kernel_key);
if (kernel_iter == iter->second.end()) { if (kernel_iter == iter->second.end()) {
return Kernel(); return empty_kernel;
} }
return kernel_iter->second; return kernel_iter->second;
} }
...@@ -59,8 +61,8 @@ KernelKeyMap KernelFactory::SelectKernelMap( ...@@ -59,8 +61,8 @@ KernelKeyMap KernelFactory::SelectKernelMap(
return iter->second; return iter->second;
} }
bool KernelFactory::IsSelectKernelValid(const std::string& kernel_name, bool KernelFactory::HasKernel(const std::string& kernel_name,
const KernelKey& kernel_key) const { const KernelKey& kernel_key) const {
auto iter = kernels_.find(kernel_name); auto iter = kernels_.find(kernel_name);
PADDLE_ENFORCE_NE( PADDLE_ENFORCE_NE(
iter, iter,
...@@ -128,6 +130,16 @@ const Kernel& KernelFactory::SelectKernelOrThrowError( ...@@ -128,6 +130,16 @@ const Kernel& KernelFactory::SelectKernelOrThrowError(
KernelKey(backend, layout, dtype)); KernelKey(backend, layout, dtype));
} }
const KernelArgsDef& KernelFactory::GetFirstKernelArgsDef(
const std::string& kernel_name) const {
auto iter = kernels_.find(kernel_name);
PADDLE_ENFORCE_NE(
iter,
kernels_.end(),
phi::errors::NotFound("The kernel `%s` is not registered.", kernel_name));
return iter->second.cbegin()->second.args_def();
}
// print kernel info with json format: // print kernel info with json format:
// { // {
// "(CPU, Undefined(AnyLayout), complex64)": { // "(CPU, Undefined(AnyLayout), complex64)": {
......
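SelectKernel now returns a const reference, so a miss has to hand back a long-lived sentinel (the new file-scope empty_kernel) instead of a temporary Kernel. A hedged sketch of the caller side, using only the signatures visible in this diff; TryDescribeKernel and its arguments are placeholders:

// Sketch only (headers omitted).
bool TryDescribeKernel(const std::string& kernel_name,
                       const phi::KernelKey& key) {
  auto& factory = phi::KernelFactory::Instance();
  // On a miss this is a reference to the static empty_kernel, never a
  // dangling reference to a temporary; IsValid() (now const) reports it.
  const phi::Kernel& kernel = factory.SelectKernel(kernel_name, key);
  if (!kernel.IsValid()) {
    return false;
  }
  // GetFirstKernelArgsDef is the helper added above; it enforces that the
  // kernel name is registered and returns the first registered args_def.
  const phi::KernelArgsDef& def = factory.GetFirstKernelArgsDef(kernel_name);
  return !def.input_defs().empty();
}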
...@@ -151,30 +151,38 @@ class KernelArgsDef { ...@@ -151,30 +151,38 @@ class KernelArgsDef {
attribute_defs_.emplace_back(AttributeArgDef(type_index)); attribute_defs_.emplace_back(AttributeArgDef(type_index));
} }
const paddle::SmallVector<TensorArgDef>& input_defs() const { const paddle::SmallVector<TensorArgDef, kInputSmallVectorSize>& input_defs()
const {
return input_defs_; return input_defs_;
} }
const paddle::SmallVector<TensorArgDef>& output_defs() const { const paddle::SmallVector<TensorArgDef, kOutputSmallVectorSize>& output_defs()
const {
return output_defs_; return output_defs_;
} }
const paddle::SmallVector<AttributeArgDef>& attribute_defs() const { const paddle::SmallVector<AttributeArgDef, kAttrSmallVectorSize>&
attribute_defs() const {
return attribute_defs_; return attribute_defs_;
} }
paddle::SmallVector<TensorArgDef>& input_defs() { return input_defs_; } paddle::SmallVector<TensorArgDef, kInputSmallVectorSize>& input_defs() {
return input_defs_;
}
paddle::SmallVector<TensorArgDef>& output_defs() { return output_defs_; } paddle::SmallVector<TensorArgDef, kOutputSmallVectorSize>& output_defs() {
return output_defs_;
}
paddle::SmallVector<AttributeArgDef>& attribute_defs() { paddle::SmallVector<AttributeArgDef, kAttrSmallVectorSize>& attribute_defs() {
return attribute_defs_; return attribute_defs_;
} }
private: private:
paddle::SmallVector<TensorArgDef> input_defs_{{}}; paddle::SmallVector<TensorArgDef, kInputSmallVectorSize> input_defs_{{}};
paddle::SmallVector<TensorArgDef> output_defs_{{}}; paddle::SmallVector<TensorArgDef, kOutputSmallVectorSize> output_defs_{{}};
paddle::SmallVector<AttributeArgDef> attribute_defs_{{}}; paddle::SmallVector<AttributeArgDef, kAttrSmallVectorSize> attribute_defs_{
{}};
}; };
class Kernel { class Kernel {
...@@ -209,7 +217,7 @@ class Kernel { ...@@ -209,7 +217,7 @@ class Kernel {
TensorArgDef& OutputAt(size_t idx) { return args_def_.output_defs().at(idx); } TensorArgDef& OutputAt(size_t idx) { return args_def_.output_defs().at(idx); }
bool IsValid() { return fn_ != nullptr; } bool IsValid() const { return fn_ != nullptr; }
private: private:
KernelFn fn_{nullptr}; KernelFn fn_{nullptr};
...@@ -246,14 +254,17 @@ class KernelFactory { ...@@ -246,14 +254,17 @@ class KernelFactory {
DataLayout layout, DataLayout layout,
DataType dtype) const; DataType dtype) const;
bool IsSelectKernelValid(const std::string& kernel_name, bool HasKernel(const std::string& kernel_name,
const KernelKey& kernel_key) const; const KernelKey& kernel_key) const;
Kernel SelectKernel(const std::string& kernel_name, const Kernel& SelectKernel(const std::string& kernel_name,
const KernelKey& kernel_key) const; const KernelKey& kernel_key) const;
KernelKeyMap SelectKernelMap(const std::string& kernel_name) const; KernelKeyMap SelectKernelMap(const std::string& kernel_name) const;
const KernelArgsDef& GetFirstKernelArgsDef(
const std::string& kernel_name) const;
private: private:
KernelFactory() = default; KernelFactory() = default;
......
...@@ -148,4 +148,6 @@ void MetaTensor::share_dims(const MetaTensor& meta_tensor) { ...@@ -148,4 +148,6 @@ void MetaTensor::share_dims(const MetaTensor& meta_tensor) {
} }
} }
bool MetaTensor::initialized() const { return tensor_ != nullptr; }
} // namespace phi } // namespace phi
...@@ -45,10 +45,10 @@ class MetaTensor { ...@@ -45,10 +45,10 @@ class MetaTensor {
: tensor_(const_cast<TensorBase*>(&tensor)) {} : tensor_(const_cast<TensorBase*>(&tensor)) {}
MetaTensor(TensorBase& tensor) : tensor_(&tensor) {} // NOLINT MetaTensor(TensorBase& tensor) : tensor_(&tensor) {} // NOLINT
MetaTensor(const MetaTensor&) = default;
MetaTensor(MetaTensor&&) = default; MetaTensor(MetaTensor&&) = default;
MetaTensor& operator=(const MetaTensor&) = delete; MetaTensor& operator=(MetaTensor&&) = default;
MetaTensor& operator=(MetaTensor&&) = delete; MetaTensor(const MetaTensor&) = default;
MetaTensor& operator=(const MetaTensor&) = default;
virtual ~MetaTensor() = default; virtual ~MetaTensor() = default;
...@@ -64,6 +64,8 @@ class MetaTensor { ...@@ -64,6 +64,8 @@ class MetaTensor {
virtual void share_meta(const MetaTensor& meta_tensor); virtual void share_meta(const MetaTensor& meta_tensor);
virtual void share_dims(const MetaTensor& meta_tensor); virtual void share_dims(const MetaTensor& meta_tensor);
virtual bool initialized() const;
private: private:
// Because the lod in compiletime and runtime is different, // Because the lod in compiletime and runtime is different,
// so `LoD` cannot in public methods // so `LoD` cannot in public methods
......
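MetaTensor becomes copy- and move-assignable so the contexts above can hold it by value, and initialized() simply reports whether it wraps a TensorBase. A small sketch under the assumption that the default constructor leaves tensor_ null (not shown in this hunk); the names below are illustrative:

// Sketch only (headers omitted).
void InitializedSketch() {
  phi::DenseTensor dense;           // any concrete TensorBase works here
  phi::MetaTensor present(dense);   // wraps a tensor, so initialized() is true
  phi::MetaTensor absent;           // assumed default ctor: tensor_ stays null
  (void)present.initialized();      // true
  (void)absent.initialized();       // false -> surfaced as nullptr by contexts
}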
...@@ -22,7 +22,7 @@ class Kernel; ...@@ -22,7 +22,7 @@ class Kernel;
class KernelKey; class KernelKey;
class KernelArgsDef; class KernelArgsDef;
class KernelContext; class KernelContext;
class KernelSignature; struct KernelSignature;
class ArgumentMappingContext; class ArgumentMappingContext;
class InferMetaContext; class InferMetaContext;
...@@ -35,4 +35,9 @@ using ArgumentMappingFn = ...@@ -35,4 +35,9 @@ using ArgumentMappingFn =
std::function<KernelSignature(const ArgumentMappingContext&)>; std::function<KernelSignature(const ArgumentMappingContext&)>;
using InferMetaFn = void (*)(InferMetaContext* ctx); using InferMetaFn = void (*)(InferMetaContext* ctx);
// Global SmallVector size setting
constexpr size_t kInputSmallVectorSize = 10U;
constexpr size_t kAttrSmallVectorSize = 10U;
constexpr size_t kOutputSmallVectorSize = 5U;
} // namespace phi } // namespace phi
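The three constants give every argument list an inline capacity, so typical ops (well under ten inputs/attributes and five outputs) should build their KernelArgsDef and InferMetaContext storage without heap allocation. A sketch of the intent, assuming paddle::SmallVector follows the usual LLVM-style small-buffer behaviour; SmallVectorSketch is an illustrative function name:

// Sketch only: paddle::SmallVector<T, N> keeps up to N elements in the
// object itself; growth beyond N falls back to a heap allocation.
void SmallVectorSketch() {
  paddle::SmallVector<const char*, phi::kInputSmallVectorSize> in_names = {
      "Param", "Grad", "LearningRate", "Moment1", "Moment2",
      "Beta1Pow", "Beta2Pow", "MasterParam", "SkipUpdate"};
  // Nine entries fit inside the inline capacity of ten (kInputSmallVectorSize),
  // so building this list per traced op does not touch the allocator.
  (void)in_names;
}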
...@@ -315,8 +315,8 @@ void MaxPoolWithIndexGradInferMeta(const MetaTensor& x, ...@@ -315,8 +315,8 @@ void MaxPoolWithIndexGradInferMeta(const MetaTensor& x,
dx->share_meta(x); dx->share_meta(x);
} }
void MeshgridGradInferMeta(const std::vector<MetaTensor*>& inputs, void MeshgridGradInferMeta(const std::vector<const MetaTensor*>& inputs,
const std::vector<MetaTensor*>& outputs_grad, const std::vector<const MetaTensor*>& outputs_grad,
std::vector<MetaTensor*> inputs_grad) { std::vector<MetaTensor*> inputs_grad) {
PADDLE_ENFORCE_GT(outputs_grad.size(), PADDLE_ENFORCE_GT(outputs_grad.size(),
1, 1,
...@@ -329,7 +329,7 @@ void MeshgridGradInferMeta(const std::vector<MetaTensor*>& inputs, ...@@ -329,7 +329,7 @@ void MeshgridGradInferMeta(const std::vector<MetaTensor*>& inputs,
} }
} }
void MultiDotGradInferMeta(const std::vector<MetaTensor*>& x, void MultiDotGradInferMeta(const std::vector<const MetaTensor*>& x,
const MetaTensor& out_grad, const MetaTensor& out_grad,
std::vector<MetaTensor*> x_grad) { std::vector<MetaTensor*> x_grad) {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
......
...@@ -151,11 +151,11 @@ void MaxPoolWithIndexGradInferMeta(const MetaTensor& x, ...@@ -151,11 +151,11 @@ void MaxPoolWithIndexGradInferMeta(const MetaTensor& x,
bool adaptive, bool adaptive,
MetaTensor* dx); MetaTensor* dx);
void MeshgridGradInferMeta(const std::vector<MetaTensor*>& inputs, void MeshgridGradInferMeta(const std::vector<const MetaTensor*>& inputs,
const std::vector<MetaTensor*>& outputs_grad, const std::vector<const MetaTensor*>& outputs_grad,
std::vector<MetaTensor*> inputs_grad); std::vector<MetaTensor*> inputs_grad);
void MultiDotGradInferMeta(const std::vector<MetaTensor*>& x, void MultiDotGradInferMeta(const std::vector<const MetaTensor*>& x,
const MetaTensor& out_grad, const MetaTensor& out_grad,
std::vector<MetaTensor*> x_grad); std::vector<MetaTensor*> x_grad);
......
...@@ -21,7 +21,8 @@ limitations under the License. */ ...@@ -21,7 +21,8 @@ limitations under the License. */
#include "paddle/phi/kernels/funcs/concat_funcs.h" #include "paddle/phi/kernels/funcs/concat_funcs.h"
namespace phi { namespace phi {
std::vector<DDim> GetMetaTensorsDim(const std::vector<MetaTensor*>& tensors) { std::vector<DDim> GetMetaTensorsDim(
const std::vector<const MetaTensor*>& tensors) {
std::vector<DDim> dims; std::vector<DDim> dims;
dims.reserve(tensors.size()); dims.reserve(tensors.size());
for (const MetaTensor* tensor : tensors) { for (const MetaTensor* tensor : tensors) {
...@@ -279,7 +280,7 @@ void AdamwInferMeta(const MetaTensor& param, ...@@ -279,7 +280,7 @@ void AdamwInferMeta(const MetaTensor& param,
master_param_outs); master_param_outs);
} }
void AddNInferMeta(const std::vector<MetaTensor*>& x, void AddNInferMeta(const std::vector<const MetaTensor*>& x,
MetaTensor* out, MetaTensor* out,
MetaConfig config) { MetaConfig config) {
auto N = x.size(); auto N = x.size();
...@@ -642,7 +643,7 @@ void BilinearTensorProductInferMeta(const MetaTensor& x, ...@@ -642,7 +643,7 @@ void BilinearTensorProductInferMeta(const MetaTensor& x,
out->set_dtype(x.dtype()); out->set_dtype(x.dtype());
} }
void BroadcastTensorsInferMeta(const std::vector<MetaTensor*>& x, void BroadcastTensorsInferMeta(const std::vector<const MetaTensor*>& x,
std::vector<MetaTensor*> out) { std::vector<MetaTensor*> out) {
int target_rank = 0; int target_rank = 0;
const auto& input_dims = GetMetaTensorsDim(x); const auto& input_dims = GetMetaTensorsDim(x);
...@@ -696,7 +697,7 @@ void BroadcastTensorsInferMeta(const std::vector<MetaTensor*>& x, ...@@ -696,7 +697,7 @@ void BroadcastTensorsInferMeta(const std::vector<MetaTensor*>& x,
} }
} }
void ConcatInferMeta(const std::vector<MetaTensor*>& x, void ConcatInferMeta(const std::vector<const MetaTensor*>& x,
const Scalar& axis_scalar, const Scalar& axis_scalar,
MetaTensor* out, MetaTensor* out,
MetaConfig config) { MetaConfig config) {
...@@ -1488,7 +1489,7 @@ void InterpolateInferMeta( ...@@ -1488,7 +1489,7 @@ void InterpolateInferMeta(
} }
} }
void MeshgridInferMeta(const std::vector<MetaTensor*>& inputs, void MeshgridInferMeta(const std::vector<const MetaTensor*>& inputs,
std::vector<MetaTensor*> outputs) { std::vector<MetaTensor*> outputs) {
const size_t inputs_num = inputs.size(); const size_t inputs_num = inputs.size();
...@@ -1551,7 +1552,8 @@ void MomentumInferMeta(const MetaTensor& param, ...@@ -1551,7 +1552,8 @@ void MomentumInferMeta(const MetaTensor& param,
} }
} }
void MultiDotInferMeta(const std::vector<MetaTensor*>& x, MetaTensor* out) { void MultiDotInferMeta(const std::vector<const MetaTensor*>& x,
MetaTensor* out) {
auto inputs_dims = GetMetaTensorsDim(x); auto inputs_dims = GetMetaTensorsDim(x);
const size_t inputs_num = inputs_dims.size(); const size_t inputs_num = inputs_dims.size();
...@@ -1624,7 +1626,7 @@ void MultiDotInferMeta(const std::vector<MetaTensor*>& x, MetaTensor* out) { ...@@ -1624,7 +1626,7 @@ void MultiDotInferMeta(const std::vector<MetaTensor*>& x, MetaTensor* out) {
out->share_lod(*x.at(0)); out->share_lod(*x.at(0));
} }
void MultiplexInferMeta(const std::vector<MetaTensor*>& ins, void MultiplexInferMeta(const std::vector<const MetaTensor*>& ins,
const MetaTensor& ids, const MetaTensor& ids,
MetaTensor* out) { MetaTensor* out) {
PADDLE_ENFORCE_NE( PADDLE_ENFORCE_NE(
...@@ -1803,8 +1805,8 @@ void RmspropInferMeta(const MetaTensor& param, ...@@ -1803,8 +1805,8 @@ void RmspropInferMeta(const MetaTensor& param,
} }
void RnnInferMeta(const MetaTensor& x, void RnnInferMeta(const MetaTensor& x,
const std::vector<MetaTensor*>& pre_state, const std::vector<const MetaTensor*>& pre_state,
const std::vector<MetaTensor*>& weight_list, const std::vector<const MetaTensor*>& weight_list,
paddle::optional<const MetaTensor&> sequence_length, paddle::optional<const MetaTensor&> sequence_length,
float dropout_prob, float dropout_prob,
bool is_bidirec, bool is_bidirec,
...@@ -1910,7 +1912,7 @@ void SgdInferMeta(const MetaTensor& param, ...@@ -1910,7 +1912,7 @@ void SgdInferMeta(const MetaTensor& param,
param_out->set_dtype(param.dtype()); param_out->set_dtype(param.dtype());
} }
void StackInferMeta(const std::vector<MetaTensor*>& x, void StackInferMeta(const std::vector<const MetaTensor*>& x,
int axis, int axis,
MetaTensor* out) { MetaTensor* out) {
PADDLE_ENFORCE_GT(x.size(), PADDLE_ENFORCE_GT(x.size(),
...@@ -1956,7 +1958,7 @@ void StackInferMeta(const std::vector<MetaTensor*>& x, ...@@ -1956,7 +1958,7 @@ void StackInferMeta(const std::vector<MetaTensor*>& x,
out->share_lod(*x.at(0)); out->share_lod(*x.at(0));
} }
void UnchangedMultiInferMeta(const std::vector<MetaTensor*>& x, void UnchangedMultiInferMeta(const std::vector<const MetaTensor*>& x,
std::vector<MetaTensor*> out) { std::vector<MetaTensor*> out) {
for (size_t i = 0; i < x.size(); ++i) { for (size_t i = 0; i < x.size(); ++i) {
out[i]->share_meta(*x[i]); out[i]->share_meta(*x[i]);
......
...@@ -35,7 +35,8 @@ namespace phi { ...@@ -35,7 +35,8 @@ namespace phi {
// //
// NOTE: The InferMeta Functions in this file are arranged in alphabetic order. // NOTE: The InferMeta Functions in this file are arranged in alphabetic order.
std::vector<DDim> GetMetaTensorsDim(const std::vector<MetaTensor*>& tensors); std::vector<DDim> GetMetaTensorsDim(
const std::vector<const MetaTensor*>& tensors);
void AdadeltaInferMeta(const MetaTensor& param, void AdadeltaInferMeta(const MetaTensor& param,
const MetaTensor& grad, const MetaTensor& grad,
...@@ -117,7 +118,7 @@ void AdamwInferMeta(const MetaTensor& param, ...@@ -117,7 +118,7 @@ void AdamwInferMeta(const MetaTensor& param,
MetaTensor* beta2_pow_out, MetaTensor* beta2_pow_out,
MetaTensor* master_param_outs); MetaTensor* master_param_outs);
void AddNInferMeta(const std::vector<MetaTensor*>& x, void AddNInferMeta(const std::vector<const MetaTensor*>& x,
MetaTensor* out, MetaTensor* out,
MetaConfig config = MetaConfig()); MetaConfig config = MetaConfig());
...@@ -173,10 +174,10 @@ void BilinearTensorProductInferMeta(const MetaTensor& x, ...@@ -173,10 +174,10 @@ void BilinearTensorProductInferMeta(const MetaTensor& x,
MetaTensor* out, MetaTensor* out,
MetaConfig config = MetaConfig()); MetaConfig config = MetaConfig());
void BroadcastTensorsInferMeta(const std::vector<MetaTensor*>& x, void BroadcastTensorsInferMeta(const std::vector<const MetaTensor*>& x,
std::vector<MetaTensor*> out); std::vector<MetaTensor*> out);
void ConcatInferMeta(const std::vector<MetaTensor*>& x, void ConcatInferMeta(const std::vector<const MetaTensor*>& x,
const Scalar& axis_scalar, const Scalar& axis_scalar,
MetaTensor* out, MetaTensor* out,
MetaConfig config = MetaConfig()); MetaConfig config = MetaConfig());
...@@ -227,7 +228,7 @@ void InterpolateInferMeta( ...@@ -227,7 +228,7 @@ void InterpolateInferMeta(
MetaTensor* output, MetaTensor* output,
MetaConfig config = MetaConfig()); MetaConfig config = MetaConfig());
void MeshgridInferMeta(const std::vector<MetaTensor*>& inputs, void MeshgridInferMeta(const std::vector<const MetaTensor*>& inputs,
std::vector<MetaTensor*> outputs); std::vector<MetaTensor*> outputs);
void MomentumInferMeta(const MetaTensor& param, void MomentumInferMeta(const MetaTensor& param,
...@@ -245,9 +246,10 @@ void MomentumInferMeta(const MetaTensor& param, ...@@ -245,9 +246,10 @@ void MomentumInferMeta(const MetaTensor& param,
MetaTensor* velocity_out, MetaTensor* velocity_out,
MetaTensor* master_param_out); MetaTensor* master_param_out);
void MultiDotInferMeta(const std::vector<MetaTensor*>& x, MetaTensor* out); void MultiDotInferMeta(const std::vector<const MetaTensor*>& x,
MetaTensor* out);
void MultiplexInferMeta(const std::vector<MetaTensor*>& ins, void MultiplexInferMeta(const std::vector<const MetaTensor*>& ins,
const MetaTensor& ids, const MetaTensor& ids,
MetaTensor* out); MetaTensor* out);
...@@ -276,8 +278,8 @@ void RmspropInferMeta(const MetaTensor& param, ...@@ -276,8 +278,8 @@ void RmspropInferMeta(const MetaTensor& param,
MetaTensor* mean_grad_out); MetaTensor* mean_grad_out);
void RnnInferMeta(const MetaTensor& x, void RnnInferMeta(const MetaTensor& x,
const std::vector<MetaTensor*>& pre_state, const std::vector<const MetaTensor*>& pre_state,
const std::vector<MetaTensor*>& weight_list, const std::vector<const MetaTensor*>& weight_list,
paddle::optional<const MetaTensor&> sequence_length, paddle::optional<const MetaTensor&> sequence_length,
float dropout_prob, float dropout_prob,
bool is_bidirec, bool is_bidirec,
...@@ -300,11 +302,11 @@ void SgdInferMeta(const MetaTensor& param, ...@@ -300,11 +302,11 @@ void SgdInferMeta(const MetaTensor& param,
MetaTensor* param_out, MetaTensor* param_out,
MetaTensor* master_param_out); MetaTensor* master_param_out);
void StackInferMeta(const std::vector<MetaTensor*>& x, void StackInferMeta(const std::vector<const MetaTensor*>& x,
int axis, int axis,
MetaTensor* out); MetaTensor* out);
void UnchangedMultiInferMeta(const std::vector<MetaTensor*>& x, void UnchangedMultiInferMeta(const std::vector<const MetaTensor*>& x,
std::vector<MetaTensor*> out); std::vector<MetaTensor*> out);
void WarpctcInferMeta(const MetaTensor& logits, void WarpctcInferMeta(const MetaTensor& logits,
......
...@@ -32,7 +32,7 @@ DenseTensor Concat(const Context& dev_ctx, ...@@ -32,7 +32,7 @@ DenseTensor Concat(const Context& dev_ctx,
const Scalar& axis) { const Scalar& axis) {
std::vector<MetaTensor> meta_x; std::vector<MetaTensor> meta_x;
meta_x.reserve(x.size()); meta_x.reserve(x.size());
std::vector<MetaTensor*> meta_x_ptr; std::vector<const MetaTensor*> meta_x_ptr;
for (const auto* t : x) { for (const auto* t : x) {
meta_x.emplace_back(*t); meta_x.emplace_back(*t);
meta_x_ptr.push_back(&meta_x.back()); meta_x_ptr.push_back(&meta_x.back());
......
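The Concat API change above builds a std::vector<MetaTensor> and, in the same loop, a parallel vector of raw pointers into it; that is only safe because meta_x.reserve(x.size()) runs first, so no reallocation invalidates the stored addresses. A standalone illustration of the same pattern with simplified, hypothetical types (Meta, MakeView):

#include <cassert>
#include <vector>

struct Meta { int value; };

// Build owning storage plus a pointer view over it, reserving up front so the
// addresses taken inside the loop stay valid.
std::vector<const Meta*> MakeView(const std::vector<int>& raw,
                                  std::vector<Meta>* storage) {
  storage->reserve(raw.size());  // without this, push_back could reallocate
  std::vector<const Meta*> view;
  view.reserve(raw.size());
  for (int v : raw) {
    storage->push_back(Meta{v});
    view.push_back(&storage->back());
  }
  return view;
}

int main() {
  std::vector<Meta> storage;
  std::vector<const Meta*> view = MakeView({1, 2, 3}, &storage);
  assert(view.size() == 3 && view[2]->value == 3);
  return 0;
}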
...@@ -21,8 +21,7 @@ KernelSignature AbsOpArgumentMapping(const ArgumentMappingContext& ctx) { ...@@ -21,8 +21,7 @@ KernelSignature AbsOpArgumentMapping(const ArgumentMappingContext& ctx) {
} }
KernelSignature AbsGradOpArgumentMapping(const ArgumentMappingContext& ctx) { KernelSignature AbsGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature( return KernelSignature("abs_grad", {"X", "Out@GRAD"}, {}, {"X@GRAD"});
"abs_grad", {"X", GradVarName("Out")}, {}, {GradVarName("X")});
} }
KernelSignature AbsDoubleGradOpArgumentMapping( KernelSignature AbsDoubleGradOpArgumentMapping(
......
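The argument-mapping changes from here on are the same mechanical substitution: GradVarName("X") only ever appended the "@GRAD" suffix (see the helper removed from the header earlier in this diff), so the literal "X@GRAD" is equivalent and avoids composing a std::string on every mapping call. A standalone check of that equivalence, reproducing the removed helper for comparison:

#include <cassert>
#include <string>

// The removed helper, reproduced here: it reserved space and appended the
// "@GRAD" suffix to the variable name.
std::string GradVarName(const std::string& var_name) {
  std::string result;
  result.reserve(var_name.size() + 5U);
  result += var_name;
  result += "@GRAD";
  return result;
}

int main() {
  // The literals used by the new mappings match the old helper's output,
  // but need no allocation and no per-call string construction.
  assert(GradVarName("X") == "X@GRAD");
  assert(GradVarName("Out") == "Out@GRAD");
  return 0;
}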
...@@ -19,26 +19,22 @@ namespace phi { ...@@ -19,26 +19,22 @@ namespace phi {
#define DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(func_name, op_name, attrs) \ #define DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(func_name, op_name, attrs) \
KernelSignature func_name##GradOpArgumentMapping( \ KernelSignature func_name##GradOpArgumentMapping( \
const ArgumentMappingContext& ctx) { \ const ArgumentMappingContext& ctx) { \
return KernelSignature(op_name "_grad", \ return KernelSignature( \
{"X", GradVarName("Out")}, \ op_name "_grad", {"X", "Out@GRAD"}, {attrs}, {"X@GRAD"}); \
{attrs}, \
{GradVarName("X")}); \
} }
#define DEFINE_ACT_GRAD_DEPOUT_OP_ARGMAP(func_name, op_name, attrs) \ #define DEFINE_ACT_GRAD_DEPOUT_OP_ARGMAP(func_name, op_name, attrs) \
KernelSignature func_name##GradOpArgumentMapping( \ KernelSignature func_name##GradOpArgumentMapping( \
const ArgumentMappingContext& ctx) { \ const ArgumentMappingContext& ctx) { \
return KernelSignature(op_name "_grad", \ return KernelSignature( \
{"Out", GradVarName("Out")}, \ op_name "_grad", {"Out", "Out@GRAD"}, {attrs}, {"X@GRAD"}); \
{attrs}, \
{GradVarName("X")}); \
} }
#define DEFINE_ACT_GRAD_NODEP_OP_ARGMAP(func_name, op_name, attrs) \ #define DEFINE_ACT_GRAD_NODEP_OP_ARGMAP(func_name, op_name, attrs) \
KernelSignature func_name##GradOpArgumentMapping( \ KernelSignature func_name##GradOpArgumentMapping( \
const ArgumentMappingContext& ctx) { \ const ArgumentMappingContext& ctx) { \
return KernelSignature( \ return KernelSignature( \
op_name "_grad", {GradVarName("Out")}, {attrs}, {GradVarName("X")}); \ op_name "_grad", {"Out@GRAD"}, {attrs}, {"X@GRAD"}); \
} }
#define comma , #define comma ,
...@@ -165,15 +161,12 @@ KernelSignature EluOpArgumentMapping(const ArgumentMappingContext& ctx) { ...@@ -165,15 +161,12 @@ KernelSignature EluOpArgumentMapping(const ArgumentMappingContext& ctx) {
} }
KernelSignature LogitGradOpArgumentMapping(const ArgumentMappingContext& ctx) { KernelSignature LogitGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature( return KernelSignature("logit_grad", {"X", "Out@GRAD"}, {"eps"}, {"X@GRAD"});
"logit_grad", {"X", GradVarName("Out")}, {"eps"}, {GradVarName("X")});
} }
KernelSignature EluGradOpArgumentMapping(const ArgumentMappingContext& ctx) { KernelSignature EluGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("elu_grad", return KernelSignature(
{"X", "Out", GradVarName("Out")}, "elu_grad", {"X", "Out", "Out@GRAD"}, {"alpha"}, {"X@GRAD"});
{"alpha"},
{GradVarName("X")});
} }
KernelSignature EluDoubleGradOpArgumentMapping( KernelSignature EluDoubleGradOpArgumentMapping(
...@@ -198,13 +191,11 @@ KernelSignature PowOpArgumentMapping(const ArgumentMappingContext& ctx) { ...@@ -198,13 +191,11 @@ KernelSignature PowOpArgumentMapping(const ArgumentMappingContext& ctx) {
KernelSignature PowGradOpArgumentMapping(const ArgumentMappingContext& ctx) { KernelSignature PowGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
if (ctx.HasInput("FactorTensor")) { if (ctx.HasInput("FactorTensor")) {
return KernelSignature("pow_grad", return KernelSignature(
{"X", GradVarName("Out")}, "pow_grad", {"X", "Out@GRAD"}, {"FactorTensor"}, {"X@GRAD"});
{"FactorTensor"},
{GradVarName("X")});
} else { } else {
return KernelSignature( return KernelSignature(
"pow_grad", {"X", GradVarName("Out")}, {"factor"}, {GradVarName("X")}); "pow_grad", {"X", "Out@GRAD"}, {"factor"}, {"X@GRAD"});
} }
} }
......
...@@ -19,7 +19,7 @@ ...@@ -19,7 +19,7 @@
namespace phi { namespace phi {
KernelSignature AdamOpArgumentMapping(const ArgumentMappingContext& ctx) { KernelSignature AdamOpArgumentMapping(const ArgumentMappingContext& ctx) {
paddle::SmallVector<std::string> in_names = {"Param", paddle::SmallVector<const char*> in_names = {"Param",
"Grad", "Grad",
"LearningRate", "LearningRate",
"Moment1", "Moment1",
...@@ -28,13 +28,13 @@ KernelSignature AdamOpArgumentMapping(const ArgumentMappingContext& ctx) { ...@@ -28,13 +28,13 @@ KernelSignature AdamOpArgumentMapping(const ArgumentMappingContext& ctx) {
"Beta2Pow", "Beta2Pow",
"MasterParam", "MasterParam",
"SkipUpdate"}; "SkipUpdate"};
paddle::SmallVector<std::string> out_names = {"ParamOut", paddle::SmallVector<const char*> out_names = {"ParamOut",
"Moment1Out", "Moment1Out",
"Moment2Out", "Moment2Out",
"Beta1PowOut", "Beta1PowOut",
"Beta2PowOut", "Beta2PowOut",
"MasterParamOut"}; "MasterParamOut"};
paddle::SmallVector<std::string> attr_names; paddle::SmallVector<const char*> attr_names;
attr_names.emplace_back(ctx.HasInput("Beta1Tensor") ? "Beta1Tensor" attr_names.emplace_back(ctx.HasInput("Beta1Tensor") ? "Beta1Tensor"
: "beta1"); : "beta1");
......
...@@ -19,7 +19,7 @@ ...@@ -19,7 +19,7 @@
namespace phi { namespace phi {
KernelSignature AdamwOpArgumentMapping(const ArgumentMappingContext& ctx) { KernelSignature AdamwOpArgumentMapping(const ArgumentMappingContext& ctx) {
paddle::SmallVector<std::string> in_names = {"Param", paddle::SmallVector<const char*> in_names = {"Param",
"Grad", "Grad",
"LearningRate", "LearningRate",
"Moment1", "Moment1",
...@@ -28,13 +28,13 @@ KernelSignature AdamwOpArgumentMapping(const ArgumentMappingContext& ctx) { ...@@ -28,13 +28,13 @@ KernelSignature AdamwOpArgumentMapping(const ArgumentMappingContext& ctx) {
"Beta2Pow", "Beta2Pow",
"MasterParam", "MasterParam",
"SkipUpdate"}; "SkipUpdate"};
paddle::SmallVector<std::string> out_names = {"ParamOut", paddle::SmallVector<const char*> out_names = {"ParamOut",
"Moment1Out", "Moment1Out",
"Moment2Out", "Moment2Out",
"Beta1PowOut", "Beta1PowOut",
"Beta2PowOut", "Beta2PowOut",
"MasterParamOut"}; "MasterParamOut"};
paddle::SmallVector<std::string> attr_names; paddle::SmallVector<const char*> attr_names;
attr_names.emplace_back(ctx.HasInput("Beta1Tensor") ? "Beta1Tensor" attr_names.emplace_back(ctx.HasInput("Beta1Tensor") ? "Beta1Tensor"
: "beta1"); : "beta1");
......
...@@ -17,11 +17,10 @@ ...@@ -17,11 +17,10 @@
namespace phi { namespace phi {
KernelSignature AddmmGradOpArgumentMapping(const ArgumentMappingContext& ctx) { KernelSignature AddmmGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature( return KernelSignature("addmm_grad",
"addmm_grad", {"Input", "X", "Y", "Out@GRAD"},
{"Input", "X", "Y", GradVarName("Out")}, {"Alpha", "Beta"},
{"Alpha", "Beta"}, {"Input@GRAD", "X@GRAD", "Y@GRAD"});
{GradVarName("Input"), GradVarName("X"), GradVarName("Y")});
} }
} // namespace phi } // namespace phi
......
...@@ -19,9 +19,9 @@ namespace phi { ...@@ -19,9 +19,9 @@ namespace phi {
KernelSignature ArgsortGradOpArgumentMapping( KernelSignature ArgsortGradOpArgumentMapping(
const ArgumentMappingContext& ctx) { const ArgumentMappingContext& ctx) {
return KernelSignature("argsort_grad", return KernelSignature("argsort_grad",
{"Indices", "X", GradVarName("Out")}, {"Indices", "X", "Out@GRAD"},
{"axis", "descending"}, {"axis", "descending"},
{GradVarName("X")}); {"X@GRAD"});
} }
} // namespace phi } // namespace phi
......
...@@ -17,10 +17,8 @@ ...@@ -17,10 +17,8 @@
namespace phi { namespace phi {
KernelSignature Atan2GradOpArgumentMapping(const ArgumentMappingContext& ctx) { KernelSignature Atan2GradOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("atan2_grad", return KernelSignature(
{"X1", "X2", GradVarName("Out")}, "atan2_grad", {"X1", "X2", "Out@GRAD"}, {}, {"X1@GRAD", "X2@GRAD"});
{},
{GradVarName("X1"), GradVarName("X2")});
} }
} // namespace phi } // namespace phi
......
...@@ -57,27 +57,26 @@ KernelSignature BatchNormOpArgumentMapping(const ArgumentMappingContext& ctx) { ...@@ -57,27 +57,26 @@ KernelSignature BatchNormOpArgumentMapping(const ArgumentMappingContext& ctx) {
KernelSignature BatchNormGradOpArgumentMapping( KernelSignature BatchNormGradOpArgumentMapping(
const ArgumentMappingContext& ctx) { const ArgumentMappingContext& ctx) {
return KernelSignature( return KernelSignature("batch_norm_grad",
"batch_norm_grad", {
{ "X",
"X", "Scale",
"Scale", "Bias",
"Bias", "Mean",
"Mean", "Variance",
"Variance", "SavedMean",
"SavedMean", "SavedVariance",
"SavedVariance", "ReserveSpace",
"ReserveSpace", "Y@GRAD",
GradVarName("Y"), },
}, {"momentum",
{"momentum", "epsilon",
"epsilon", "data_layout",
"data_layout", "is_test",
"is_test", "use_global_stats",
"use_global_stats", "trainable_statistics",
"trainable_statistics", "fuse_with_relu"},
"fuse_with_relu"}, {"X@GRAD", "Scale@GRAD", "Bias@GRAD"});
{GradVarName("X"), GradVarName("Scale"), GradVarName("Bias")});
} }
KernelSignature BatchNormGradGradOpArgumentMapping( KernelSignature BatchNormGradGradOpArgumentMapping(
......
...@@ -18,10 +18,8 @@ namespace phi { ...@@ -18,10 +18,8 @@ namespace phi {
KernelSignature BCELossGradOpArgumentMapping( KernelSignature BCELossGradOpArgumentMapping(
const ArgumentMappingContext& ctx) { const ArgumentMappingContext& ctx) {
return KernelSignature("bce_loss_grad", return KernelSignature(
{"X", "Label", GradVarName("Out")}, "bce_loss_grad", {"X", "Label", "Out@GRAD"}, {}, {"X@GRAD"});
{},
{GradVarName("X")});
} }
} // namespace phi } // namespace phi
......
...@@ -25,12 +25,9 @@ KernelSignature BilinearTensorProductOpArgumentMapping( ...@@ -25,12 +25,9 @@ KernelSignature BilinearTensorProductOpArgumentMapping(
KernelSignature BilinearTensorProductGradOpArgumentMapping( KernelSignature BilinearTensorProductGradOpArgumentMapping(
const ArgumentMappingContext& ctx) { const ArgumentMappingContext& ctx) {
return KernelSignature("bilinear_tensor_product_grad", return KernelSignature("bilinear_tensor_product_grad",
{"X", "Y", "Weight", GradVarName("Out")}, {"X", "Y", "Weight", "Out@GRAD"},
{}, {},
{GradVarName("X"), {"X@GRAD", "Y@GRAD", "Weight@GRAD", "Bias@GRAD"});
GradVarName("Y"),
GradVarName("Weight"),
GradVarName("Bias")});
} }
} // namespace phi } // namespace phi
......
...@@ -19,7 +19,7 @@ namespace phi { ...@@ -19,7 +19,7 @@ namespace phi {
KernelSignature BroadcastTensorsGradOpArgumentMapping( KernelSignature BroadcastTensorsGradOpArgumentMapping(
const ArgumentMappingContext& ctx) { const ArgumentMappingContext& ctx) {
return KernelSignature( return KernelSignature(
"broadcast_tensors_grad", {GradVarName("Out")}, {}, {GradVarName("X")}); "broadcast_tensors_grad", {"Out@GRAD"}, {}, {"X@GRAD"});
} }
} // namespace phi } // namespace phi
......
...@@ -18,10 +18,8 @@ namespace phi { ...@@ -18,10 +18,8 @@ namespace phi {
KernelSignature CholeskyGradOpArgumentMapping( KernelSignature CholeskyGradOpArgumentMapping(
const ArgumentMappingContext& ctx) { const ArgumentMappingContext& ctx) {
return KernelSignature("cholesky_grad", return KernelSignature(
{"Out", GradVarName("Out")}, "cholesky_grad", {"Out", "Out@GRAD"}, {"upper"}, {"X@GRAD"});
{"upper"},
{GradVarName("X")});
} }
} // namespace phi } // namespace phi
......
...@@ -19,9 +19,9 @@ namespace phi { ...@@ -19,9 +19,9 @@ namespace phi {
KernelSignature CholeskySolveGradOpArgumentMapping( KernelSignature CholeskySolveGradOpArgumentMapping(
const ArgumentMappingContext& ctx) { const ArgumentMappingContext& ctx) {
return KernelSignature("cholesky_solve_grad", return KernelSignature("cholesky_solve_grad",
{"X", "Y", "Out", GradVarName("Out")}, {"X", "Y", "Out", "Out@GRAD"},
{"upper"}, {"upper"},
{GradVarName("X"), GradVarName("Y")}); {"X@GRAD", "Y@GRAD"});
} }
} // namespace phi } // namespace phi
......
...@@ -18,7 +18,7 @@ ...@@ -18,7 +18,7 @@
namespace phi { namespace phi {
KernelSignature ClipOpArgumentMapping(const ArgumentMappingContext& ctx) { KernelSignature ClipOpArgumentMapping(const ArgumentMappingContext& ctx) {
paddle::SmallVector<std::string> attr_names; paddle::SmallVector<std::string, kAttrSmallVectorSize> attr_names;
attr_names.emplace_back(ctx.HasInput("Min") ? "Min" : "min"); attr_names.emplace_back(ctx.HasInput("Min") ? "Min" : "min");
attr_names.emplace_back(ctx.HasInput("Max") ? "Max" : "max"); attr_names.emplace_back(ctx.HasInput("Max") ? "Max" : "max");
if (ctx.IsDenseTensorInput("X")) { if (ctx.IsDenseTensorInput("X")) {
...@@ -57,27 +57,19 @@ KernelSignature ClipOpArgumentMapping(const ArgumentMappingContext& ctx) { ...@@ -57,27 +57,19 @@ KernelSignature ClipOpArgumentMapping(const ArgumentMappingContext& ctx) {
KernelSignature ClipGradOpArgumentMapping(const ArgumentMappingContext& ctx) { KernelSignature ClipGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
if (ctx.HasInput("Min")) { if (ctx.HasInput("Min")) {
if (ctx.HasInput("Max")) { if (ctx.HasInput("Max")) {
return KernelSignature("clip_grad", return KernelSignature(
{"X", GradVarName("Out")}, "clip_grad", {"X", "Out@GRAD"}, {"Min", "Max"}, {"X@GRAD"});
{"Min", "Max"},
{GradVarName("X")});
} else { } else {
return KernelSignature("clip_grad", return KernelSignature(
{"X", GradVarName("Out")}, "clip_grad", {"X", "Out@GRAD"}, {"Min", "max"}, {"X@GRAD"});
{"Min", "max"},
{GradVarName("X")});
} }
} else { } else {
if (ctx.HasInput("Max")) { if (ctx.HasInput("Max")) {
return KernelSignature("clip_grad", return KernelSignature(
{"X", GradVarName("Out")}, "clip_grad", {"X", "Out@GRAD"}, {"min", "Max"}, {"X@GRAD"});
{"min", "Max"},
{GradVarName("X")});
} else { } else {
return KernelSignature("clip_grad", return KernelSignature(
{"X", GradVarName("Out")}, "clip_grad", {"X", "Out@GRAD"}, {"min", "max"}, {"X@GRAD"});
{"min", "max"},
{GradVarName("X")});
} }
} }
} }
......
...@@ -17,13 +17,11 @@ ...@@ -17,13 +17,11 @@
namespace phi { namespace phi {
KernelSignature RealGradOpArgumentMapping(const ArgumentMappingContext& ctx) { KernelSignature RealGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature( return KernelSignature("real_grad", {"Out@GRAD"}, {}, {"X@GRAD"});
"real_grad", {GradVarName("Out")}, {}, {GradVarName("X")});
} }
KernelSignature ImagGradOpArgumentMapping(const ArgumentMappingContext& ctx) { KernelSignature ImagGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature( return KernelSignature("imag_grad", {"Out@GRAD"}, {}, {"X@GRAD"});
"imag_grad", {GradVarName("Out")}, {}, {GradVarName("X")});
} }
} // namespace phi } // namespace phi
......
...@@ -25,15 +25,11 @@ KernelSignature ConcatOpArgumentMapping(const ArgumentMappingContext& ctx) { ...@@ -25,15 +25,11 @@ KernelSignature ConcatOpArgumentMapping(const ArgumentMappingContext& ctx) {
KernelSignature ConcatGradOpArgumentMapping(const ArgumentMappingContext& ctx) { KernelSignature ConcatGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
if (ctx.HasInput("AxisTensor")) { if (ctx.HasInput("AxisTensor")) {
return KernelSignature("concat_grad", return KernelSignature(
{"X", {GradVarName("Out")}}, "concat_grad", {"X", {"Out@GRAD"}}, {"AxisTensor"}, {{"X@GRAD"}});
{"AxisTensor"},
{{GradVarName("X")}});
} }
return KernelSignature("concat_grad", return KernelSignature(
{"X", {GradVarName("Out")}}, "concat_grad", {"X", {"Out@GRAD"}}, {"axis"}, {{"X@GRAD"}});
{"axis"},
{{GradVarName("X")}});
} }
} // namespace phi } // namespace phi
......
...@@ -46,7 +46,7 @@ KernelSignature Conv2dOpArgumentMapping(const ArgumentMappingContext& ctx) { ...@@ -46,7 +46,7 @@ KernelSignature Conv2dOpArgumentMapping(const ArgumentMappingContext& ctx) {
KernelSignature Conv2dGradOpArgumentMapping(const ArgumentMappingContext& ctx) { KernelSignature Conv2dGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("conv2d_grad", return KernelSignature("conv2d_grad",
{"Input", "Filter", GradVarName("Output")}, {"Input", "Filter", "Output@GRAD"},
{"strides", {"strides",
"paddings", "paddings",
"padding_algorithm", "padding_algorithm",
...@@ -56,7 +56,7 @@ KernelSignature Conv2dGradOpArgumentMapping(const ArgumentMappingContext& ctx) { ...@@ -56,7 +56,7 @@ KernelSignature Conv2dGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
"use_addto", "use_addto",
"workspace_size_MB", "workspace_size_MB",
"exhaustive_search"}, "exhaustive_search"},
{GradVarName("Input"), GradVarName("Filter")}); {"Input@GRAD", "Filter@GRAD"});
} }
KernelSignature Conv2dDoubleGradOpArgumentMapping( KernelSignature Conv2dDoubleGradOpArgumentMapping(
......
...@@ -33,7 +33,7 @@ KernelSignature Conv3dOpArgumentMapping(const ArgumentMappingContext& ctx) { ...@@ -33,7 +33,7 @@ KernelSignature Conv3dOpArgumentMapping(const ArgumentMappingContext& ctx) {
KernelSignature Conv3dGradOpArgumentMapping(const ArgumentMappingContext& ctx) { KernelSignature Conv3dGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("conv2d_grad", return KernelSignature("conv2d_grad",
{"Input", "Filter", GradVarName("Output")}, {"Input", "Filter", "Output@GRAD"},
{"strides", {"strides",
"paddings", "paddings",
"padding_algorithm", "padding_algorithm",
...@@ -43,7 +43,7 @@ KernelSignature Conv3dGradOpArgumentMapping(const ArgumentMappingContext& ctx) { ...@@ -43,7 +43,7 @@ KernelSignature Conv3dGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
"use_addto", "use_addto",
"workspace_size_MB", "workspace_size_MB",
"exhaustive_search"}, "exhaustive_search"},
{GradVarName("Input"), GradVarName("Filter")}); {"Input@GRAD", "Filter@GRAD"});
} }
KernelSignature Conv3dDoubleGradOpArgumentMapping( KernelSignature Conv3dDoubleGradOpArgumentMapping(
......
...@@ -34,7 +34,7 @@ KernelSignature Conv2dTransposeOpArgumentMapping( ...@@ -34,7 +34,7 @@ KernelSignature Conv2dTransposeOpArgumentMapping(
KernelSignature Conv2dTransposeGradOpArgumentMapping( KernelSignature Conv2dTransposeGradOpArgumentMapping(
const ArgumentMappingContext& ctx) { const ArgumentMappingContext& ctx) {
return KernelSignature("conv2d_transpose_grad", return KernelSignature("conv2d_transpose_grad",
{"Input", "Filter", GradVarName("Output")}, {"Input", "Filter", "Output@GRAD"},
{"strides", {"strides",
"paddings", "paddings",
"output_padding", "output_padding",
...@@ -43,7 +43,7 @@ KernelSignature Conv2dTransposeGradOpArgumentMapping( ...@@ -43,7 +43,7 @@ KernelSignature Conv2dTransposeGradOpArgumentMapping(
"groups", "groups",
"dilations", "dilations",
"data_format"}, "data_format"},
{GradVarName("Input"), GradVarName("Filter")}); {"Input@GRAD", "Filter@GRAD"});
} }
KernelSignature Conv2dTransposeDoubleGradOpArgumentMapping( KernelSignature Conv2dTransposeDoubleGradOpArgumentMapping(
...@@ -79,7 +79,7 @@ KernelSignature Conv3dTransposeOpArgumentMapping( ...@@ -79,7 +79,7 @@ KernelSignature Conv3dTransposeOpArgumentMapping(
KernelSignature Conv3dTransposeGradOpArgumentMapping( KernelSignature Conv3dTransposeGradOpArgumentMapping(
const ArgumentMappingContext& ctx) { const ArgumentMappingContext& ctx) {
return KernelSignature("conv3d_transpose_grad", return KernelSignature("conv3d_transpose_grad",
{"Input", "Filter", GradVarName("Output")}, {"Input", "Filter", "Output@GRAD"},
{"strides", {"strides",
"paddings", "paddings",
"output_padding", "output_padding",
...@@ -88,7 +88,7 @@ KernelSignature Conv3dTransposeGradOpArgumentMapping( ...@@ -88,7 +88,7 @@ KernelSignature Conv3dTransposeGradOpArgumentMapping(
"groups", "groups",
"dilations", "dilations",
"data_format"}, "data_format"},
{GradVarName("Input"), GradVarName("Filter")}); {"Input@GRAD", "Filter@GRAD"});
} }
KernelSignature DepthwiseConv2dTransposeOpArgumentMapping( KernelSignature DepthwiseConv2dTransposeOpArgumentMapping(
...@@ -109,7 +109,7 @@ KernelSignature DepthwiseConv2dTransposeOpArgumentMapping( ...@@ -109,7 +109,7 @@ KernelSignature DepthwiseConv2dTransposeOpArgumentMapping(
KernelSignature DepthwiseConv2dTransposeGradOpArgumentMapping( KernelSignature DepthwiseConv2dTransposeGradOpArgumentMapping(
const ArgumentMappingContext& ctx) { const ArgumentMappingContext& ctx) {
return KernelSignature("depthwise_conv2d_transpose_grad", return KernelSignature("depthwise_conv2d_transpose_grad",
{"Input", "Filter", GradVarName("Output")}, {"Input", "Filter", "Output@GRAD"},
{"strides", {"strides",
"paddings", "paddings",
"output_padding", "output_padding",
...@@ -118,7 +118,7 @@ KernelSignature DepthwiseConv2dTransposeGradOpArgumentMapping( ...@@ -118,7 +118,7 @@ KernelSignature DepthwiseConv2dTransposeGradOpArgumentMapping(
"groups", "groups",
"dilations", "dilations",
"data_format"}, "data_format"},
{GradVarName("Input"), GradVarName("Filter")}); {"Input@GRAD", "Filter@GRAD"});
} }
} // namespace phi } // namespace phi
......
...@@ -21,10 +21,8 @@ KernelSignature CrossOpArgumentMapping(const ArgumentMappingContext& ctx) { ...@@ -21,10 +21,8 @@ KernelSignature CrossOpArgumentMapping(const ArgumentMappingContext& ctx) {
} }
KernelSignature CrossGradOpArgumentMapping(const ArgumentMappingContext& ctx) { KernelSignature CrossGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("cross_grad", return KernelSignature(
{"X", "Y", GradVarName("Out")}, "cross_grad", {"X", "Y", "Out@GRAD"}, {"dim"}, {"X@GRAD", "Y@GRAD"});
{"dim"},
{GradVarName("X"), GradVarName("Y")});
} }
} // namespace phi } // namespace phi
......
...@@ -18,10 +18,8 @@ namespace phi { ...@@ -18,10 +18,8 @@ namespace phi {
KernelSignature CumprodGradGradOpArgumentMapping( KernelSignature CumprodGradGradOpArgumentMapping(
const ArgumentMappingContext& ctx) { const ArgumentMappingContext& ctx) {
return KernelSignature("cumprod_grad", return KernelSignature(
{"X", "Out", GradVarName("Out")}, "cumprod_grad", {"X", "Out", "Out@GRAD"}, {"dim"}, {"X@GRAD"});
{"dim"},
{GradVarName("X")});
} }
} // namespace phi } // namespace phi
......
...@@ -33,17 +33,14 @@ KernelSignature DeformableConvGradOpArgumentMapping( ...@@ -33,17 +33,14 @@ KernelSignature DeformableConvGradOpArgumentMapping(
const ArgumentMappingContext& ctx) { const ArgumentMappingContext& ctx) {
return KernelSignature( return KernelSignature(
"deformable_conv_grad", "deformable_conv_grad",
{"Input", "Offset", "Filter", "Mask", GradVarName("Output")}, {"Input", "Offset", "Filter", "Mask", "Output@GRAD"},
{"strides", {"strides",
"paddings", "paddings",
"dilations", "dilations",
"deformable_groups", "deformable_groups",
"groups", "groups",
"im2col_step"}, "im2col_step"},
{GradVarName("Input"), {"Input@GRAD", "Offset@GRAD", "Filter@GRAD", "Mask@GRAD"});
GradVarName("Offset"),
GradVarName("Filter"),
GradVarName("Mask")});
} }
} // namespace phi } // namespace phi
......
...@@ -36,7 +36,7 @@ KernelSignature DepthwiseConv2dOpArgumentMapping( ...@@ -36,7 +36,7 @@ KernelSignature DepthwiseConv2dOpArgumentMapping(
KernelSignature DepthwiseConv2dGradOpArgumentMapping( KernelSignature DepthwiseConv2dGradOpArgumentMapping(
const ArgumentMappingContext& ctx) { const ArgumentMappingContext& ctx) {
return KernelSignature("depthwise_conv2d_grad", return KernelSignature("depthwise_conv2d_grad",
{"Input", "Filter", GradVarName("Output")}, {"Input", "Filter", "Output@GRAD"},
{"strides", {"strides",
"paddings", "paddings",
"padding_algorithm", "padding_algorithm",
...@@ -47,7 +47,7 @@ KernelSignature DepthwiseConv2dGradOpArgumentMapping( ...@@ -47,7 +47,7 @@ KernelSignature DepthwiseConv2dGradOpArgumentMapping(
"workspace_size_MB", "workspace_size_MB",
"exhaustive_search", "exhaustive_search",
"fuse_relu_before_depthwise_conv"}, "fuse_relu_before_depthwise_conv"},
{GradVarName("Input"), GradVarName("Filter")}); {"Input@GRAD", "Filter@GRAD"});
} }
KernelSignature DepthwiseConv2dDoubleGradOpArgumentMapping( KernelSignature DepthwiseConv2dDoubleGradOpArgumentMapping(
......
...@@ -18,10 +18,8 @@ namespace phi { ...@@ -18,10 +18,8 @@ namespace phi {
KernelSignature DeterminantGradOpArgumentMapping( KernelSignature DeterminantGradOpArgumentMapping(
const ArgumentMappingContext& ctx) { const ArgumentMappingContext& ctx) {
return KernelSignature("determinant_grad", return KernelSignature(
{"Input", "Out", GradVarName("Out")}, "determinant_grad", {"Input", "Out", "Out@GRAD"}, {}, {"Input@GRAD"});
{},
{GradVarName("Input")});
} }
} // namespace phi } // namespace phi
......
...@@ -22,7 +22,7 @@ KernelSignature DiagOpArgumentMapping(const ArgumentMappingContext& ctx) { ...@@ -22,7 +22,7 @@ KernelSignature DiagOpArgumentMapping(const ArgumentMappingContext& ctx) {
KernelSignature DiagGradOpArgumentMapping(const ArgumentMappingContext& ctx) { KernelSignature DiagGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature( return KernelSignature(
"diag_grad", {"X", GradVarName("Out")}, {"offset"}, {GradVarName("X")}); "diag_grad", {"X", "Out@GRAD"}, {"offset"}, {"X@GRAD"});
} }
} // namespace phi } // namespace phi
......
...@@ -19,9 +19,9 @@ namespace phi { ...@@ -19,9 +19,9 @@ namespace phi {
KernelSignature DiagonalGradOpArgumentMapping( KernelSignature DiagonalGradOpArgumentMapping(
const ArgumentMappingContext& ctx) { const ArgumentMappingContext& ctx) {
return KernelSignature("diagonal_grad", return KernelSignature("diagonal_grad",
{"Input", GradVarName("Out")}, {"Input", "Out@GRAD"},
{"offset", "axis1", "axis2"}, {"offset", "axis1", "axis2"},
{GradVarName("Input")}); {"Input@GRAD"});
} }
} // namespace phi } // namespace phi
......
...@@ -18,8 +18,7 @@ namespace phi { ...@@ -18,8 +18,7 @@ namespace phi {
KernelSignature DigammaGradOpArgumentMapping( KernelSignature DigammaGradOpArgumentMapping(
const ArgumentMappingContext& ctx) { const ArgumentMappingContext& ctx) {
return KernelSignature( return KernelSignature("digamma_grad", {"X", "Out@GRAD"}, {}, {"X@GRAD"});
"digamma_grad", {"X", GradVarName("Out")}, {}, {GradVarName("X")});
} }
} // namespace phi } // namespace phi
......
...@@ -17,10 +17,8 @@ limitations under the License. */ ...@@ -17,10 +17,8 @@ limitations under the License. */
namespace phi { namespace phi {
KernelSignature DistGradOpArgumentMapping(const ArgumentMappingContext& ctx) { KernelSignature DistGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("dist_grad", return KernelSignature(
{"X", "Y", "Out", GradVarName("Out")}, "dist_grad", {"X", "Y", "Out", "Out@GRAD"}, {"p"}, {"X@GRAD", "Y@GRAD"});
{"p"},
{GradVarName("X"), GradVarName("Y")});
} }
} // namespace phi } // namespace phi
......
...@@ -17,10 +17,8 @@ limitations under the License. */ ...@@ -17,10 +17,8 @@ limitations under the License. */
namespace phi { namespace phi {
KernelSignature DotGradOpArgumentMapping(const ArgumentMappingContext& ctx) { KernelSignature DotGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("dot_grad", return KernelSignature(
{"X", "Y", GradVarName("Out")}, "dot_grad", {"X", "Y", "Out@GRAD"}, {}, {"X@GRAD", "Y@GRAD"});
{},
{GradVarName("X"), GradVarName("Y")});
} }
} // namespace phi } // namespace phi
......
...@@ -27,9 +27,9 @@ KernelSignature DropoutOpArgumentMapping(const ArgumentMappingContext& ctx) { ...@@ -27,9 +27,9 @@ KernelSignature DropoutOpArgumentMapping(const ArgumentMappingContext& ctx) {
KernelSignature DropoutGradOpArgumentMapping( KernelSignature DropoutGradOpArgumentMapping(
const ArgumentMappingContext& ctx) { const ArgumentMappingContext& ctx) {
return KernelSignature("dropout_grad", return KernelSignature("dropout_grad",
{"Mask", GradVarName("Out")}, {"Mask", "Out@GRAD"},
{"dropout_prob", "is_test", "dropout_implementation"}, {"dropout_prob", "is_test", "dropout_implementation"},
{GradVarName("X")}); {"X@GRAD"});
} }
} // namespace phi } // namespace phi
......
...@@ -17,13 +17,11 @@ ...@@ -17,13 +17,11 @@
namespace phi { namespace phi {
KernelSignature EighGradOpArgumentMapping(const ArgumentMappingContext& ctx) { KernelSignature EighGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("eigh_grad", return KernelSignature(
{"Eigenvalues", "eigh_grad",
"Eigenvectors", {"Eigenvalues", "Eigenvectors", "Eigenvalues@GRAD", "Eigenvectors@GRAD"},
GradVarName("Eigenvalues"), {},
GradVarName("Eigenvectors")}, {"X@GRAD"});
{},
{GradVarName("X")});
} }
} // namespace phi } // namespace phi
......
...@@ -106,10 +106,8 @@ KernelSignature ElementwisePowOpArgumentMapping( ...@@ -106,10 +106,8 @@ KernelSignature ElementwisePowOpArgumentMapping(
KernelSignature ElementwiseAddGradOpArgumentMapping( KernelSignature ElementwiseAddGradOpArgumentMapping(
const ArgumentMappingContext& ctx) { const ArgumentMappingContext& ctx) {
return KernelSignature("add_grad", return KernelSignature(
{"X", "Y", GradVarName("Out")}, "add_grad", {"X", "Y", "Out@GRAD"}, {"axis"}, {"X@GRAD", "Y@GRAD"});
{"axis"},
{GradVarName("X"), GradVarName("Y")});
} }
KernelSignature ElementwiseAddDoubleGradOpArgumentMapping( KernelSignature ElementwiseAddDoubleGradOpArgumentMapping(
...@@ -128,10 +126,8 @@ KernelSignature ElementwiseAddTripleGradOpArgumentMapping( ...@@ -128,10 +126,8 @@ KernelSignature ElementwiseAddTripleGradOpArgumentMapping(
KernelSignature ElementwiseSubGradOpArgumentMapping( KernelSignature ElementwiseSubGradOpArgumentMapping(
const ArgumentMappingContext& ctx) { const ArgumentMappingContext& ctx) {
return KernelSignature("subtract_grad", return KernelSignature(
{"X", "Y", GradVarName("Out")}, "subtract_grad", {"X", "Y", "Out@GRAD"}, {"axis"}, {"X@GRAD", "Y@GRAD"});
{"axis"},
{GradVarName("X"), GradVarName("Y")});
} }
KernelSignature ElementwiseSubDoubleGradOpArgumentMapping( KernelSignature ElementwiseSubDoubleGradOpArgumentMapping(
...@@ -143,17 +139,15 @@ KernelSignature ElementwiseSubDoubleGradOpArgumentMapping( ...@@ -143,17 +139,15 @@ KernelSignature ElementwiseSubDoubleGradOpArgumentMapping(
KernelSignature ElementwiseDivGradOpArgumentMapping( KernelSignature ElementwiseDivGradOpArgumentMapping(
const ArgumentMappingContext& ctx) { const ArgumentMappingContext& ctx) {
return KernelSignature("divide_grad", return KernelSignature("divide_grad",
{"X", "Y", "Out", GradVarName("Out")}, {"X", "Y", "Out", "Out@GRAD"},
{"axis"}, {"axis"},
{GradVarName("X"), GradVarName("Y")}); {"X@GRAD", "Y@GRAD"});
} }
KernelSignature ElementwiseFMinGradOpArgumentMapping( KernelSignature ElementwiseFMinGradOpArgumentMapping(
const ArgumentMappingContext& ctx) { const ArgumentMappingContext& ctx) {
return KernelSignature("fmin_grad", return KernelSignature(
{"X", "Y", GradVarName("Out")}, "fmin_grad", {"X", "Y", "Out@GRAD"}, {"axis"}, {"X@GRAD", "Y@GRAD"});
{"axis"},
{GradVarName("X"), GradVarName("Y")});
} }
KernelSignature ElementwiseDivDoubleGradOpArgumentMapping( KernelSignature ElementwiseDivDoubleGradOpArgumentMapping(
...@@ -161,15 +155,13 @@ KernelSignature ElementwiseDivDoubleGradOpArgumentMapping( ...@@ -161,15 +155,13 @@ KernelSignature ElementwiseDivDoubleGradOpArgumentMapping(
return KernelSignature("divide_double_grad", return KernelSignature("divide_double_grad",
{"Y", "Out", "DX", "DDX", "DDY"}, {"Y", "Out", "DX", "DDX", "DDY"},
{"axis"}, {"axis"},
{GradVarName("Y"), "DOut", "DDOut"}); {"Y@GRAD", "DOut", "DDOut"});
} }
KernelSignature ElementwiseMulGradOpArgumentMapping( KernelSignature ElementwiseMulGradOpArgumentMapping(
const ArgumentMappingContext& ctx) { const ArgumentMappingContext& ctx) {
return KernelSignature("multiply_grad", return KernelSignature(
{"X", "Y", GradVarName("Out")}, "multiply_grad", {"X", "Y", "Out@GRAD"}, {"axis"}, {"X@GRAD", "Y@GRAD"});
{"axis"},
{GradVarName("X"), GradVarName("Y")});
} }
KernelSignature ElementwiseFMaxOpArgumentMapping( KernelSignature ElementwiseFMaxOpArgumentMapping(
...@@ -184,10 +176,8 @@ KernelSignature ElementwiseFMinOpArgumentMapping( ...@@ -184,10 +176,8 @@ KernelSignature ElementwiseFMinOpArgumentMapping(
KernelSignature ElementwiseFMaxGradOpArgumentMapping( KernelSignature ElementwiseFMaxGradOpArgumentMapping(
const ArgumentMappingContext& ctx) { const ArgumentMappingContext& ctx) {
return KernelSignature("fmax_grad", return KernelSignature(
{"X", "Y", GradVarName("Out")}, "fmax_grad", {"X", "Y", "Out@GRAD"}, {"axis"}, {"X@GRAD", "Y@GRAD"});
{"axis"},
{GradVarName("X"), GradVarName("Y")});
} }
KernelSignature ElementwiseMulDoubleGradOpArgumentMapping( KernelSignature ElementwiseMulDoubleGradOpArgumentMapping(
...@@ -195,7 +185,7 @@ KernelSignature ElementwiseMulDoubleGradOpArgumentMapping( ...@@ -195,7 +185,7 @@ KernelSignature ElementwiseMulDoubleGradOpArgumentMapping(
return KernelSignature("multiply_double_grad", return KernelSignature("multiply_double_grad",
{"X", "Y", "DOut", "DDX", "DDY"}, {"X", "Y", "DOut", "DDX", "DDY"},
{"axis"}, {"axis"},
{GradVarName("X"), GradVarName("Y"), "DDOut"}); {"X@GRAD", "Y@GRAD", "DDOut"});
} }
KernelSignature ElementwiseMulTripleGradOpArgumentMapping( KernelSignature ElementwiseMulTripleGradOpArgumentMapping(
...@@ -209,25 +199,21 @@ KernelSignature ElementwiseMulTripleGradOpArgumentMapping( ...@@ -209,25 +199,21 @@ KernelSignature ElementwiseMulTripleGradOpArgumentMapping(
KernelSignature ElementwiseMaxGradOpArgumentMapping( KernelSignature ElementwiseMaxGradOpArgumentMapping(
const ArgumentMappingContext& ctx) { const ArgumentMappingContext& ctx) {
return KernelSignature("maximum_grad", return KernelSignature(
{"X", "Y", GradVarName("Out")}, "maximum_grad", {"X", "Y", "Out@GRAD"}, {"axis"}, {"X@GRAD", "Y@GRAD"});
{"axis"},
{GradVarName("X"), GradVarName("Y")});
} }
KernelSignature ElementwiseMinGradOpArgumentMapping( KernelSignature ElementwiseMinGradOpArgumentMapping(
const ArgumentMappingContext& ctx) { const ArgumentMappingContext& ctx) {
return KernelSignature("minimum_grad", return KernelSignature(
{"X", "Y", GradVarName("Out")}, "minimum_grad", {"X", "Y", "Out@GRAD"}, {"axis"}, {"X@GRAD", "Y@GRAD"});
{"axis"},
{GradVarName("X"), GradVarName("Y")});
} }
KernelSignature ElementwisePowGradOpArgumentMapping( KernelSignature ElementwisePowGradOpArgumentMapping(
const ArgumentMappingContext& ctx) { const ArgumentMappingContext& ctx) {
return KernelSignature("elementwise_pow_grad", return KernelSignature("elementwise_pow_grad",
{"X", "Y", GradVarName("Out")}, {"X", "Y", "Out@GRAD"},
{"axis"}, {"axis"},
{GradVarName("X"), GradVarName("Y")}); {"X@GRAD", "Y@GRAD"});
} }
} // namespace phi } // namespace phi
......
...@@ -30,26 +30,26 @@ KernelSignature EmbeddingGradOpArgumentMapping( ...@@ -30,26 +30,26 @@ KernelSignature EmbeddingGradOpArgumentMapping(
if (ctx.IsDenseTensorInput("W")) { if (ctx.IsDenseTensorInput("W")) {
if ((paddle::any_cast<bool>(ctx.Attr("is_sparse"))) == true) { if ((paddle::any_cast<bool>(ctx.Attr("is_sparse"))) == true) {
return KernelSignature("embedding_sparse_grad", return KernelSignature("embedding_sparse_grad",
{"Ids", "W", GradVarName("Out")}, {"Ids", "W", "Out@GRAD"},
{"padding_idx"}, {"padding_idx"},
{GradVarName("W")}); {"W@GRAD"});
} else { } else {
return KernelSignature("embedding_grad", return KernelSignature("embedding_grad",
{"Ids", "W", GradVarName("Out")}, {"Ids", "W", "Out@GRAD"},
{"padding_idx"}, {"padding_idx"},
{GradVarName("W")}); {"W@GRAD"});
} }
} else { } else {
if ((paddle::any_cast<bool>(ctx.Attr("is_sparse"))) == true) { if ((paddle::any_cast<bool>(ctx.Attr("is_sparse"))) == true) {
return KernelSignature("sparse_weight_embedding_sparse_grad", return KernelSignature("sparse_weight_embedding_sparse_grad",
{"Ids", "W", GradVarName("Out")}, {"Ids", "W", "Out@GRAD"},
{"padding_idx"}, {"padding_idx"},
{GradVarName("W")}); {"W@GRAD"});
} else { } else {
return KernelSignature("sparse_weight_embedding_grad", return KernelSignature("sparse_weight_embedding_grad",
{"Ids", "W", GradVarName("Out")}, {"Ids", "W", "Out@GRAD"},
{"padding_idx"}, {"padding_idx"},
{GradVarName("W")}); {"W@GRAD"});
} }
} }
} }
......
...@@ -17,8 +17,7 @@ ...@@ -17,8 +17,7 @@
namespace phi { namespace phi {
KernelSignature ErfGradOpArgumentMapping(const ArgumentMappingContext& ctx) { KernelSignature ErfGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature( return KernelSignature("erf_grad", {"X", "Out@GRAD"}, {}, {"X@GRAD"});
"erf_grad", {"X", GradVarName("Out")}, {}, {GradVarName("X")});
} }
} // namespace phi } // namespace phi
......
...@@ -17,8 +17,7 @@ ...@@ -17,8 +17,7 @@
namespace phi { namespace phi {
KernelSignature ErfinvGradOpArgumentMapping(const ArgumentMappingContext& ctx) { KernelSignature ErfinvGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature( return KernelSignature("erfinv_grad", {"Out", "Out@GRAD"}, {}, {"X@GRAD"});
"erfinv_grad", {"Out", GradVarName("Out")}, {}, {GradVarName("X")});
} }
} // namespace phi } // namespace phi
......
...@@ -22,10 +22,8 @@ KernelSignature ExpandAsOpArgumentMapping(const ArgumentMappingContext& ctx) { ...@@ -22,10 +22,8 @@ KernelSignature ExpandAsOpArgumentMapping(const ArgumentMappingContext& ctx) {
KernelSignature ExpandAsGradOpArgumentMapping( KernelSignature ExpandAsGradOpArgumentMapping(
const ArgumentMappingContext& ctx) { const ArgumentMappingContext& ctx) {
return KernelSignature("expand_as_grad", return KernelSignature(
{"X", GradVarName("Out")}, "expand_as_grad", {"X", "Out@GRAD"}, {"target_shape"}, {"X@GRAD"});
{"target_shape"},
{GradVarName("X")});
} }
} // namespace phi } // namespace phi
......
...@@ -28,20 +28,14 @@ KernelSignature ExpandOpArgumentMapping(const ArgumentMappingContext& ctx) { ...@@ -28,20 +28,14 @@ KernelSignature ExpandOpArgumentMapping(const ArgumentMappingContext& ctx) {
KernelSignature ExpandGradOpArgumentMapping(const ArgumentMappingContext& ctx) { KernelSignature ExpandGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
if (ctx.HasInput("Shape")) { if (ctx.HasInput("Shape")) {
return KernelSignature("expand_grad", return KernelSignature(
{"X", GradVarName("Out")}, "expand_grad", {"X", "Out@GRAD"}, {"Shape"}, {"X@GRAD"});
{"Shape"},
{GradVarName("X")});
} else if (ctx.InputSize("expand_shapes_tensor") > 0) { } else if (ctx.InputSize("expand_shapes_tensor") > 0) {
return KernelSignature("expand_grad", return KernelSignature(
{"X", GradVarName("Out")}, "expand_grad", {"X", "Out@GRAD"}, {"expand_shapes_tensor"}, {"X@GRAD"});
{"expand_shapes_tensor"},
{GradVarName("X")});
} else { } else {
return KernelSignature("expand_grad", return KernelSignature(
{"X", GradVarName("Out")}, "expand_grad", {"X", "Out@GRAD"}, {"shape"}, {"X@GRAD"});
{"shape"},
{GradVarName("X")});
} }
} }
......
...@@ -31,7 +31,7 @@ KernelSignature FlattenOpArgumentMapping(const ArgumentMappingContext& ctx) { ...@@ -31,7 +31,7 @@ KernelSignature FlattenOpArgumentMapping(const ArgumentMappingContext& ctx) {
KernelSignature FlattenGradOpArgumentMapping( KernelSignature FlattenGradOpArgumentMapping(
const ArgumentMappingContext& ctx) { const ArgumentMappingContext& ctx) {
return KernelSignature( return KernelSignature(
"flatten_grad", {"XShape", GradVarName("Out")}, {}, {GradVarName("X")}); "flatten_grad", {"XShape", "Out@GRAD"}, {}, {"X@GRAD"});
} }
} // namespace phi } // namespace phi
......
...@@ -25,9 +25,9 @@ KernelSignature FrobeniusNormOpArgumentMapping( ...@@ -25,9 +25,9 @@ KernelSignature FrobeniusNormOpArgumentMapping(
KernelSignature FrobeniusNormGradOpArgumentMapping( KernelSignature FrobeniusNormGradOpArgumentMapping(
const ArgumentMappingContext& ctx) { const ArgumentMappingContext& ctx) {
return KernelSignature("frobenius_norm_grad", return KernelSignature("frobenius_norm_grad",
{"X", "Out", GradVarName("Out")}, {"X", "Out", "Out@GRAD"},
{"dim", "keep_dim", "reduce_all"}, {"dim", "keep_dim", "reduce_all"},
{GradVarName("X")}); {"X@GRAD"});
} }
} // namespace phi } // namespace phi
......
...@@ -17,25 +17,23 @@ ...@@ -17,25 +17,23 @@
namespace phi { namespace phi {
KernelSignature GatherNdGradArgumentMapping(const ArgumentMappingContext& ctx) { KernelSignature GatherNdGradArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("gather_nd_grad", return KernelSignature(
{"X", "Index", GradVarName("Out")}, "gather_nd_grad", {"X", "Index", "Out@GRAD"}, {}, {"X@GRAD"});
{},
{GradVarName("X")});
} }
KernelSignature ScatterGradArgumentMapping(const ArgumentMappingContext& ctx) { KernelSignature ScatterGradArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("scatter_grad", return KernelSignature("scatter_grad",
{"Ids", "Updates", GradVarName("Out")}, {"Ids", "Updates", "Out@GRAD"},
{"overwrite"}, {"overwrite"},
{GradVarName("X"), GradVarName("Updates")}); {"X@GRAD", "Updates@GRAD"});
} }
KernelSignature ScatterNdAddGradArgumentMapping( KernelSignature ScatterNdAddGradArgumentMapping(
const ArgumentMappingContext& ctx) { const ArgumentMappingContext& ctx) {
return KernelSignature("scatter_nd_add_grad", return KernelSignature("scatter_nd_add_grad",
{"Index", "Updates", GradVarName("Out")}, {"Index", "Updates", "Out@GRAD"},
{}, {},
{GradVarName("X"), GradVarName("Updates")}); {"X@GRAD", "Updates@GRAD"});
} }
} // namespace phi } // namespace phi
......
...@@ -27,14 +27,14 @@ KernelSignature GatherOpArgumentMapping(const ArgumentMappingContext& ctx) { ...@@ -27,14 +27,14 @@ KernelSignature GatherOpArgumentMapping(const ArgumentMappingContext& ctx) {
KernelSignature GatherGradOpArgumentMapping(const ArgumentMappingContext& ctx) { KernelSignature GatherGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
if (ctx.HasInput("Axis")) { if (ctx.HasInput("Axis")) {
return KernelSignature("gather_grad", return KernelSignature("gather_grad",
{"X", "Index", GradVarName("Out")}, {"X", "Index", "Out@GRAD"},
{"Axis", "overwrite"}, {"Axis", "overwrite"},
{GradVarName("X")}); {"X@GRAD"});
} else { } else {
return KernelSignature("gather_grad", return KernelSignature("gather_grad",
{"X", "Index", GradVarName("Out")}, {"X", "Index", "Out@GRAD"},
{"axis", "overwrite"}, {"axis", "overwrite"},
{GradVarName("X")}); {"X@GRAD"});
} }
} }
......
...@@ -21,10 +21,8 @@ KernelSignature GeluOpArgumentMapping(const ArgumentMappingContext& ctx) { ...@@ -21,10 +21,8 @@ KernelSignature GeluOpArgumentMapping(const ArgumentMappingContext& ctx) {
} }
KernelSignature GeluGradOpArgumentMapping(const ArgumentMappingContext& ctx) { KernelSignature GeluGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("gelu_grad", return KernelSignature(
{"X", GradVarName("Out")}, "gelu_grad", {"X", "Out@GRAD"}, {"approximate"}, {"X@GRAD"});
{"approximate"},
{GradVarName("X")});
} }
} // namespace phi } // namespace phi
......
...@@ -28,9 +28,9 @@ KernelSignature GraphSendRecvGradOpArgumentMapping( ...@@ -28,9 +28,9 @@ KernelSignature GraphSendRecvGradOpArgumentMapping(
const ArgumentMappingContext& ctx) { const ArgumentMappingContext& ctx) {
return KernelSignature( return KernelSignature(
"graph_send_recv_grad", "graph_send_recv_grad",
{"X", "Src_index", "Dst_index", "Out", "Dst_count", GradVarName("Out")}, {"X", "Src_index", "Dst_index", "Out", "Dst_count", "Out@GRAD"},
{"pool_type"}, {"pool_type"},
{GradVarName("X")}); {"X@GRAD"});
} }
} // namespace phi } // namespace phi
......
...@@ -27,9 +27,9 @@ KernelSignature GridSamplerOpArgumentMapping( ...@@ -27,9 +27,9 @@ KernelSignature GridSamplerOpArgumentMapping(
KernelSignature GridSamplerGradOpArgumentMapping( KernelSignature GridSamplerGradOpArgumentMapping(
const ArgumentMappingContext& ctx) { const ArgumentMappingContext& ctx) {
return KernelSignature("grid_sample_grad", return KernelSignature("grid_sample_grad",
{"X", "Grid", GradVarName("Output")}, {"X", "Grid", "Output@GRAD"},
{"mode", "padding_mode", "align_corners"}, {"mode", "padding_mode", "align_corners"},
{GradVarName("X"), GradVarName("Grid")}); {"X@GRAD", "Grid@GRAD"});
} }
} // namespace phi } // namespace phi
......
...@@ -18,10 +18,8 @@ namespace phi { ...@@ -18,10 +18,8 @@ namespace phi {
KernelSignature GumbelSoftmaxGradOpArgumentMapping( KernelSignature GumbelSoftmaxGradOpArgumentMapping(
const ArgumentMappingContext& ctx) { const ArgumentMappingContext& ctx) {
return KernelSignature("gumbel_softmax_grad", return KernelSignature(
{"Out", GradVarName("Out")}, "gumbel_softmax_grad", {"Out", "Out@GRAD"}, {"axis"}, {"X@GRAD"});
{"axis"},
{GradVarName("X")});
} }
} // namespace phi } // namespace phi
......
...@@ -32,44 +32,42 @@ KernelSignature HierarchicalSigmoidOpArgumentMapping( ...@@ -32,44 +32,42 @@ KernelSignature HierarchicalSigmoidOpArgumentMapping(
KernelSignature HierarchicalSigmoidGradOpArgumentMapping( KernelSignature HierarchicalSigmoidGradOpArgumentMapping(
const ArgumentMappingContext& ctx) { const ArgumentMappingContext& ctx) {
if (ctx.IsDenseTensorOutput(GradVarName("W"))) { if (ctx.IsDenseTensorOutput("W@GRAD")) {
return KernelSignature( return KernelSignature("hierarchical_sigmoid_grad",
"hierarchical_sigmoid_grad", {"X",
{"X", "W",
"W", "Label",
"Label", "PathTable",
"PathTable", "PathCode",
"PathCode", "Bias",
"Bias", "PreOut",
"PreOut", "Out@GRAD"},
GradVarName("Out")}, {"num_classes",
{"num_classes", "remote_prefetch",
"remote_prefetch", "trainer_id",
"trainer_id", "height_sections",
"height_sections", "epmap",
"epmap", "table_names",
"table_names", "is_sparse"},
"is_sparse"}, {"X@GRAD", "W@GRAD", "Bias@GRAD"});
{GradVarName("X"), GradVarName("W"), GradVarName("Bias")}); } else if (ctx.IsSelectedRowsOutput("W@GRAD")) {
} else if (ctx.IsSelectedRowsOutput(GradVarName("W"))) { return KernelSignature("hierarchical_sigmoid_grad_sr",
return KernelSignature( {"X",
"hierarchical_sigmoid_grad_sr", "W",
{"X", "Label",
"W", "PathTable",
"Label", "PathCode",
"PathTable", "Bias",
"PathCode", "PreOut",
"Bias", "Out@GRAD"},
"PreOut", {"num_classes",
GradVarName("Out")}, "remote_prefetch",
{"num_classes", "trainer_id",
"remote_prefetch", "height_sections",
"trainer_id", "epmap",
"height_sections", "table_names",
"epmap", "is_sparse"},
"table_names", {"X@GRAD", "W@GRAD", "Bias@GRAD"});
"is_sparse"},
{GradVarName("X"), GradVarName("W"), GradVarName("Bias")});
} else { } else {
return KernelSignature("unregistered", {}, {}, {}); return KernelSignature("unregistered", {}, {}, {});
} }
......
...@@ -24,9 +24,9 @@ KernelSignature HuberLossOpArgumentMapping(const ArgumentMappingContext& ctx) { ...@@ -24,9 +24,9 @@ KernelSignature HuberLossOpArgumentMapping(const ArgumentMappingContext& ctx) {
KernelSignature HuberLossGradOpArgumentMapping( KernelSignature HuberLossGradOpArgumentMapping(
const ArgumentMappingContext& ctx) { const ArgumentMappingContext& ctx) {
return KernelSignature("huber_loss_grad", return KernelSignature("huber_loss_grad",
{"Residual", GradVarName("Out")}, {"Residual", "Out@GRAD"},
{"delta"}, {"delta"},
{GradVarName("X"), GradVarName("Y")}); {"X@GRAD", "Y@GRAD"});
} }
} // namespace phi } // namespace phi
......
...@@ -18,10 +18,8 @@ namespace phi { ...@@ -18,10 +18,8 @@ namespace phi {
KernelSignature IndexSampleGradOpArgumentMapping( KernelSignature IndexSampleGradOpArgumentMapping(
const ArgumentMappingContext& ctx) { const ArgumentMappingContext& ctx) {
return KernelSignature("index_sample_grad", return KernelSignature(
{"X", "Index", GradVarName("Out")}, "index_sample_grad", {"X", "Index", "Out@GRAD"}, {}, {"X@GRAD"});
{},
{GradVarName("X")});
} }
} // namespace phi } // namespace phi
......
...@@ -18,10 +18,8 @@ namespace phi { ...@@ -18,10 +18,8 @@ namespace phi {
KernelSignature IndexSelectGradOpArgumentMapping( KernelSignature IndexSelectGradOpArgumentMapping(
const ArgumentMappingContext& ctx) { const ArgumentMappingContext& ctx) {
return KernelSignature("index_select_grad", return KernelSignature(
{"X", "Index", GradVarName("Out")}, "index_select_grad", {"X", "Index", "Out@GRAD"}, {"dim"}, {"X@GRAD"});
{"dim"},
{GradVarName("X")});
} }
} // namespace phi } // namespace phi
......
...@@ -92,81 +92,76 @@ KernelSignature BicubicInterpOpArgumentMapping( ...@@ -92,81 +92,76 @@ KernelSignature BicubicInterpOpArgumentMapping(
KernelSignature BilinearInterpGradOpArgumentMapping( KernelSignature BilinearInterpGradOpArgumentMapping(
const ArgumentMappingContext& ctx) { const ArgumentMappingContext& ctx) {
return KernelSignature( return KernelSignature("bilinear_interp_v2_grad",
"bilinear_interp_v2_grad", {"X", "OutSize", "SizeTensor", "Scale", "Out@GRAD"},
{"X", "OutSize", "SizeTensor", "Scale", GradVarName("Out")}, {"data_layout",
{"data_layout", "out_d",
"out_d", "out_h",
"out_h", "out_w",
"out_w", "scale",
"scale", "interp_method",
"interp_method", "align_corners",
"align_corners", "align_mode"},
"align_mode"}, {"X@GRAD"});
{GradVarName("X")});
} }
KernelSignature NearestInterpGradOpArgumentMapping( KernelSignature NearestInterpGradOpArgumentMapping(
const ArgumentMappingContext& ctx) { const ArgumentMappingContext& ctx) {
return KernelSignature( return KernelSignature("nearest_interp_v2_grad",
"nearest_interp_v2_grad", {"X", "OutSize", "SizeTensor", "Scale", "Out@GRAD"},
{"X", "OutSize", "SizeTensor", "Scale", GradVarName("Out")}, {"data_layout",
{"data_layout", "out_d",
"out_d", "out_h",
"out_h", "out_w",
"out_w", "scale",
"scale", "interp_method",
"interp_method", "align_corners",
"align_corners", "align_mode"},
"align_mode"}, {"X@GRAD"});
{GradVarName("X")});
} }
KernelSignature TrilinearInterpGradOpArgumentMapping( KernelSignature TrilinearInterpGradOpArgumentMapping(
const ArgumentMappingContext& ctx) { const ArgumentMappingContext& ctx) {
return KernelSignature( return KernelSignature("trilinear_interp_v2_grad",
"trilinear_interp_v2_grad", {"X", "OutSize", "SizeTensor", "Scale", "Out@GRAD"},
{"X", "OutSize", "SizeTensor", "Scale", GradVarName("Out")}, {"data_layout",
{"data_layout", "out_d",
"out_d", "out_h",
"out_h", "out_w",
"out_w", "scale",
"scale", "interp_method",
"interp_method", "align_corners",
"align_corners", "align_mode"},
"align_mode"}, {"X@GRAD"});
{GradVarName("X")});
} }
KernelSignature LinearInterpGradOpArgumentMapping( KernelSignature LinearInterpGradOpArgumentMapping(
const ArgumentMappingContext& ctx) { const ArgumentMappingContext& ctx) {
return KernelSignature( return KernelSignature("linear_interp_v2_grad",
"linear_interp_v2_grad", {"X", "OutSize", "SizeTensor", "Scale", "Out@GRAD"},
{"X", "OutSize", "SizeTensor", "Scale", GradVarName("Out")}, {"data_layout",
{"data_layout", "out_d",
"out_d", "out_h",
"out_h", "out_w",
"out_w", "scale",
"scale", "interp_method",
"interp_method", "align_corners",
"align_corners", "align_mode"},
"align_mode"}, {"X@GRAD"});
{GradVarName("X")});
} }
KernelSignature BicubicInterpGradOpArgumentMapping( KernelSignature BicubicInterpGradOpArgumentMapping(
const ArgumentMappingContext& ctx) { const ArgumentMappingContext& ctx) {
return KernelSignature( return KernelSignature("bicubic_interp_v2_grad",
"bicubic_interp_v2_grad", {"X", "OutSize", "SizeTensor", "Scale", "Out@GRAD"},
{"X", "OutSize", "SizeTensor", "Scale", GradVarName("Out")}, {"data_layout",
{"data_layout", "out_d",
"out_d", "out_h",
"out_h", "out_w",
"out_w", "scale",
"scale", "interp_method",
"interp_method", "align_corners",
"align_corners", "align_mode"},
"align_mode"}, {"X@GRAD"});
{GradVarName("X")});
} }
} // namespace phi } // namespace phi
......
...@@ -20,9 +20,9 @@ namespace phi { ...@@ -20,9 +20,9 @@ namespace phi {
KernelSignature KLDivLossGradOpArgumentMapping( KernelSignature KLDivLossGradOpArgumentMapping(
const ArgumentMappingContext& ctx) { const ArgumentMappingContext& ctx) {
return KernelSignature("kldiv_loss_grad", return KernelSignature("kldiv_loss_grad",
{"X", "Target", GradVarName("Loss")}, {"X", "Target", "Loss@GRAD"},
{"reduction"}, {"reduction"},
{GradVarName("X")}); {"X@GRAD"});
} }
} // namespace phi } // namespace phi
......
...@@ -17,10 +17,8 @@ ...@@ -17,10 +17,8 @@
namespace phi { namespace phi {
KernelSignature KronGradOpArgumentMapping(const ArgumentMappingContext& ctx) { KernelSignature KronGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("kron_grad", return KernelSignature(
{"X", "Y", GradVarName("Out")}, "kron_grad", {"X", "Y", "Out@GRAD"}, {}, {"X@GRAD", "Y@GRAD"});
{},
{GradVarName("X"), GradVarName("Y")});
} }
} // namespace phi } // namespace phi
......
...@@ -20,9 +20,9 @@ namespace phi { ...@@ -20,9 +20,9 @@ namespace phi {
KernelSignature KthvalueGradOpArgumentMapping( KernelSignature KthvalueGradOpArgumentMapping(
const ArgumentMappingContext& ctx) { const ArgumentMappingContext& ctx) {
return KernelSignature("kthvalue_grad", return KernelSignature("kthvalue_grad",
{"X", "Indices", GradVarName("Out")}, {"X", "Indices", "Out@GRAD"},
{"k", "axis", "keepdim"}, {"k", "axis", "keepdim"},
{GradVarName("X")}); {"X@GRAD"});
} }
} // namespace phi } // namespace phi
......
...@@ -24,10 +24,8 @@ KernelSignature LabelSmoothOpArgumentMapping( ...@@ -24,10 +24,8 @@ KernelSignature LabelSmoothOpArgumentMapping(
KernelSignature LabelSmoothGradOpArgumentMapping( KernelSignature LabelSmoothGradOpArgumentMapping(
const ArgumentMappingContext& ctx) { const ArgumentMappingContext& ctx) {
return KernelSignature("label_smooth_grad", return KernelSignature(
{GradVarName("Out")}, "label_smooth_grad", {"Out@GRAD"}, {"epsilon"}, {"X@GRAD"});
{"epsilon"},
{GradVarName("X")});
} }
} // namespace phi } // namespace phi
......
...@@ -25,11 +25,10 @@ KernelSignature LayerNormOpArgumentMapping(const ArgumentMappingContext& ctx) { ...@@ -25,11 +25,10 @@ KernelSignature LayerNormOpArgumentMapping(const ArgumentMappingContext& ctx) {
KernelSignature LayerNormGradOpArgumentMapping( KernelSignature LayerNormGradOpArgumentMapping(
const ArgumentMappingContext& ctx) { const ArgumentMappingContext& ctx) {
return KernelSignature( return KernelSignature("layer_norm_grad",
"layer_norm_grad", {"X", "Scale", "Bias", "Mean", "Variance", "Y@GRAD"},
{"X", "Scale", "Bias", "Mean", "Variance", GradVarName("Y")}, {"epsilon", "begin_norm_axis", "is_test"},
{"epsilon", "begin_norm_axis", "is_test"}, {"X@GRAD", "Scale@GRAD", "Bias@GRAD"});
{GradVarName("X"), GradVarName("Scale"), GradVarName("Bias")});
} }
} // namespace phi } // namespace phi
......
...@@ -22,9 +22,9 @@ KernelSignature LerpOpArgumentMapping(const ArgumentMappingContext& ctx) { ...@@ -22,9 +22,9 @@ KernelSignature LerpOpArgumentMapping(const ArgumentMappingContext& ctx) {
KernelSignature LerpGradOpArgumentMapping(const ArgumentMappingContext& ctx) { KernelSignature LerpGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("lerp_grad", return KernelSignature("lerp_grad",
{"X", "Y", "Weight", "Out", GradVarName("Out")}, {"X", "Y", "Weight", "Out", "Out@GRAD"},
{}, {},
{GradVarName("X"), GradVarName("Y")}); {"X@GRAD", "Y@GRAD"});
} }
} // namespace phi } // namespace phi
......
...@@ -17,8 +17,7 @@ ...@@ -17,8 +17,7 @@
namespace phi { namespace phi {
KernelSignature LgammaGradOpArgumentMapping(const ArgumentMappingContext& ctx) { KernelSignature LgammaGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature( return KernelSignature("lgamma_grad", {"X", "Out@GRAD"}, {}, {"X@GRAD"});
"lgamma_grad", {"X", GradVarName("Out")}, {}, {GradVarName("X")});
} }
} // namespace phi } // namespace phi
......
...@@ -19,9 +19,9 @@ namespace phi { ...@@ -19,9 +19,9 @@ namespace phi {
KernelSignature LogLossGradOpArgumentMapping( KernelSignature LogLossGradOpArgumentMapping(
const ArgumentMappingContext& ctx) { const ArgumentMappingContext& ctx) {
return KernelSignature("log_loss_grad", return KernelSignature("log_loss_grad",
{"Predicted", "Labels", GradVarName("Loss")}, {"Predicted", "Labels", "Loss@GRAD"},
{"epsilon"}, {"epsilon"},
{GradVarName("Predicted")}); {"Predicted@GRAD"});
} }
} // namespace phi } // namespace phi
......
...@@ -18,10 +18,8 @@ namespace phi { ...@@ -18,10 +18,8 @@ namespace phi {
KernelSignature LogSoftmaxGradOpArgumentMapping( KernelSignature LogSoftmaxGradOpArgumentMapping(
const ArgumentMappingContext& ctx) { const ArgumentMappingContext& ctx) {
return KernelSignature("log_softmax_grad", return KernelSignature(
{"Out", GradVarName("Out")}, "log_softmax_grad", {"Out", "Out@GRAD"}, {"axis"}, {"X@GRAD"});
{"axis"},
{GradVarName("X")});
} }
} // namespace phi } // namespace phi
......
...@@ -19,9 +19,9 @@ namespace phi { ...@@ -19,9 +19,9 @@ namespace phi {
KernelSignature LogsumexpGradOpArgumentMapping( KernelSignature LogsumexpGradOpArgumentMapping(
const ArgumentMappingContext& ctx) { const ArgumentMappingContext& ctx) {
return KernelSignature("logsumexp_grad", return KernelSignature("logsumexp_grad",
{"X", "Out", GradVarName("Out")}, {"X", "Out", "Out@GRAD"},
{"axis", "keepdim", "reduce_all"}, {"axis", "keepdim", "reduce_all"},
{GradVarName("X")}); {"X@GRAD"});
} }
} // namespace phi } // namespace phi
......
...@@ -23,10 +23,8 @@ KernelSignature MaskedSelectOpArgumentMapping( ...@@ -23,10 +23,8 @@ KernelSignature MaskedSelectOpArgumentMapping(
KernelSignature MaskedSelectGradOpArgumentMapping( KernelSignature MaskedSelectGradOpArgumentMapping(
const ArgumentMappingContext& ctx) { const ArgumentMappingContext& ctx) {
return KernelSignature("masked_select_grad", return KernelSignature(
{"X", "Mask", GradVarName("Y")}, "masked_select_grad", {"X", "Mask", "Y@GRAD"}, {}, {"X@GRAD"});
{},
{GradVarName("X")});
} }
} // namespace phi } // namespace phi
......
...@@ -19,14 +19,14 @@ namespace phi { ...@@ -19,14 +19,14 @@ namespace phi {
KernelSignature MatmulGradOpArgumentMapping(const ArgumentMappingContext& ctx) { KernelSignature MatmulGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
if (ctx.HasAttr("use_addto")) { if (ctx.HasAttr("use_addto")) {
return KernelSignature("addto_matmul_grad", return KernelSignature("addto_matmul_grad",
{"X", "Y", GradVarName("Out")}, {"X", "Y", "Out@GRAD"},
{"trans_x", "trans_y", "use_addto"}, {"trans_x", "trans_y", "use_addto"},
{GradVarName("X"), GradVarName("Y")}); {"X@GRAD", "Y@GRAD"});
} else { } else {
return KernelSignature("matmul_grad", return KernelSignature("matmul_grad",
{"X", "Y", GradVarName("Out")}, {"X", "Y", "Out@GRAD"},
{"trans_x", "trans_y"}, {"trans_x", "trans_y"},
{GradVarName("X"), GradVarName("Y")}); {"X@GRAD", "Y@GRAD"});
} }
} }
......
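Note: every hunk above makes the same mechanical substitution in the phi argument-mapping files: calls to the `GradVarName("X")` helper are replaced with the literal suffixed name `"X@GRAD"`, so the gradient-variable string no longer has to be rebuilt each time a mapping function runs, which matches this PR's scheduling-performance goal. The sketch below shows the post-change style for a hypothetical op; `my_op`, its tensor and attribute names, the header path, and the registration macro are illustrative assumptions, and only the `KernelSignature(kernel_name, inputs, attrs, outputs)` shape is taken from the hunks above.

```cpp
// Illustrative sketch only: a grad-op argument mapping written in the
// post-change style. "my_op", its tensor/attribute names, the header path,
// and the registration macro are assumptions, not part of this PR.
#include "paddle/phi/core/compat/op_utils.h"

namespace phi {

KernelSignature MyOpGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
  // Gradient tensors are referred to by the literal "@GRAD" suffix instead
  // of building the name at runtime through GradVarName().
  return KernelSignature(
      "my_op_grad", {"X", "Out@GRAD"}, {"axis"}, {"X@GRAD"});
}

}  // namespace phi

PD_REGISTER_ARG_MAPPING_FN(my_op_grad, phi::MyOpGradOpArgumentMapping);
```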
(The diffs for the remaining 25 changed files are collapsed in this view.)