Unverified commit ec1d2a16, authored by Chen Weihang, committed by GitHub

[Cherry-pick] Optimize dygraph scheduling performance (#42010)

* [Phi] Support setting size of vector<Tensor> for out in yaml (#41576)

* support setting vector out size in yaml

* support setting size of vector<tensor> for out in yaml

* resolve conflict
Co-authored-by: zyfncg <zhangyunfei07@baidu.com>
Parent e4cb897e
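The yaml change in the first hunk below is easier to read with an example in mind: an operator's output declaration can now carry a `{...}` expression after a `Tensor[]` return to state how many tensors the output vector holds, which is why `ParseYamlReturns` strips everything from the first `{` before parsing the return name and type. A minimal sketch of that parsing rule (the `parse_returns` helper and the sample declaration are illustrative, not code from this commit):

    def parse_returns(returns_str):
        parsed = []
        for ret in (x.strip() for x in returns_str.split(",")):
            # Drop the optional {...} size expression before parsing name/type.
            ret = ret.split("{")[0].strip()
            name = ret[ret.find("(") + 1:ret.find(")")] if "(" in ret else ""
            parsed.append((ret.split("(")[0].strip(), name))
        return parsed

    # Hypothetical declaration "Tensor(xshape), Tensor[](out){input.size()}"
    # parses to [("Tensor", "xshape"), ("Tensor[]", "out")].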
@@ -235,7 +235,7 @@ def ParseYamlReturns(string):
     returns = [x.strip() for x in string.strip().split(",")]
     for i in range(len(returns)):
-        ret = returns[i]
+        ret = returns[i].split("{")[0].strip()
         ret_name = ""
         if "(" in ret and ")" in ret:
......
@@ -308,10 +308,100 @@ void CompatMetaTensor::share_meta(const MetaTensor& meta_tensor) {
   share_lod(meta_tensor);
 }
 
-phi::InferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
-                                            const std::string& op_type) {
+void CompatInferMetaContext::EmplaceBackInput(CompatMetaTensor input) {
+  int index = compat_inputs_.size();
+  compat_inputs_.emplace_back(std::move(input));
+  input_range_.emplace_back(std::pair<int, int>(index, index + 1));
+}
+
+void CompatInferMetaContext::EmplaceBackOutput(CompatMetaTensor output) {
+  int index = compat_outputs_.size();
+  compat_outputs_.emplace_back(std::move(output));
+  output_range_.emplace_back(std::pair<int, int>(index, index + 1));
+}
+
+void CompatInferMetaContext::EmplaceBackInputs(
+    paddle::SmallVector<CompatMetaTensor, phi::kInputSmallVectorSize> inputs) {
+  int index = compat_inputs_.size();
+  input_range_.emplace_back(std::pair<int, int>(index, index + inputs.size()));
+  compat_inputs_.insert(compat_inputs_.end(),
+                        std::make_move_iterator(inputs.begin()),
+                        std::make_move_iterator(inputs.end()));
+}
+
+void CompatInferMetaContext::EmplaceBackOutputs(
+    paddle::SmallVector<CompatMetaTensor, phi::kOutputSmallVectorSize>
+        outputs) {
+  int index = compat_outputs_.size();
+  output_range_.emplace_back(
+      std::pair<int, int>(index, index + outputs.size()));
+  compat_outputs_.insert(compat_outputs_.end(),
+                         std::make_move_iterator(outputs.begin()),
+                         std::make_move_iterator(outputs.end()));
+}
+
+const phi::MetaTensor& CompatInferMetaContext::InputAt(size_t idx) const {
+  return compat_inputs_.at(idx);
+}
+
+paddle::optional<const phi::MetaTensor&>
+CompatInferMetaContext::OptionalInputAt(size_t idx) const {
+  const auto& input = compat_inputs_.at(idx);
+  return input.initialized()
+             ? paddle::optional<const phi::MetaTensor&>{input}
+             : paddle::optional<const phi::MetaTensor&>{paddle::none};
+}
+
+std::vector<const phi::MetaTensor*> CompatInferMetaContext::InputsBetween(
+    size_t start, size_t end) const {
+  std::vector<const phi::MetaTensor*> result;
+  result.reserve(end - start);
+  for (size_t i = start; i < end; ++i) {
+    auto& in = compat_inputs_.at(i);
+    result.emplace_back(in.initialized() ? &in : nullptr);
+  }
+  return result;
+}
+
+paddle::optional<const std::vector<const phi::MetaTensor*>>
+CompatInferMetaContext::OptionalInputsBetween(size_t start, size_t end) const {
+  const auto& first = compat_inputs_.at(start);
+  if (first.initialized()) {
+    std::vector<const phi::MetaTensor*> result;
+    result.reserve(end - start);
+    for (size_t i = start; i < end; ++i) {
+      auto& in = compat_inputs_.at(i);
+      result.emplace_back(in.initialized() ? &in : nullptr);
+    }
+    return paddle::optional<const std::vector<const phi::MetaTensor*>>(result);
+  }
+  return paddle::optional<const std::vector<const phi::MetaTensor*>>(
+      paddle::none);
+}
+
+phi::MetaTensor* CompatInferMetaContext::MutableOutputAt(size_t idx) {
+  auto& out = compat_outputs_.at(idx);
+  return out.initialized() ? &out : nullptr;
+}
+
+std::vector<phi::MetaTensor*> CompatInferMetaContext::MutableOutputBetween(
+    size_t start, size_t end) {
+  std::vector<phi::MetaTensor*> result;
+  result.reserve(end - start);
+  for (size_t i = start; i < end; ++i) {
+    auto& out = compat_outputs_.at(i);
+    result.emplace_back(out.initialized() ? &out : nullptr);
+  }
+  return result;
+}
+
+CompatInferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
+                                             const std::string& op_type) {
   // 1. get kernel args
+  InitDefaultKernelSignatureMap();
   auto arg_map_fn = phi::OpUtilsMap::Instance().GetArgumentMappingFn(op_type);
   PADDLE_ENFORCE_NOT_NULL(
       arg_map_fn, platform::errors::NotFound(
@@ -321,52 +411,47 @@ phi::InferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
   VLOG(3) << "BuildInferMetaContext: op kernel signature - " << signature;
 
   // 2. build infermeta context
-  phi::InferMetaContext infer_meta_context(
+  CompatInferMetaContext infer_meta_context(
       {ctx->IsRuntime(), ctx->IsRunMKLDNNKernel()});
 
   auto& input_names = std::get<0>(signature.args);
   auto& attr_names = std::get<1>(signature.args);
   auto& output_names = std::get<2>(signature.args);
 
-  auto kernels_map =
-      phi::KernelFactory::Instance().SelectKernelMap(signature.name);
-  if (kernels_map.size() == 0) {
-    PADDLE_THROW(
-        platform::errors::Unimplemented("Not find `%s` kernels when construct "
-                                        "InferMetaContext.",
-                                        signature.name));
-  }
-  auto attr_defs = kernels_map.cbegin()->second.args_def().attribute_defs();
-
-  // TODO(chenweihang): support multiple inputs and outputs later
-  phi::InferMetaContext infer_mete_context;
+  const auto& args_def =
+      phi::KernelFactory::Instance().GetFirstKernelArgsDef(signature.name);
+  const auto& attr_defs = args_def.attribute_defs();
+
   for (auto& in_name : input_names) {
     if (ctx->HasInputs(in_name)) {
-      auto input_var = ctx->GetInputVarPtrs(in_name);
+      auto input_var = std::move(ctx->GetInputVarPtrs(in_name));
       if (input_var.size() == 1) {
         infer_meta_context.EmplaceBackInput(
-            std::make_shared<CompatMetaTensor>(input_var[0], ctx->IsRuntime()));
+            std::move(CompatMetaTensor(input_var[0], ctx->IsRuntime())));
       } else {
-        paddle::SmallVector<std::shared_ptr<phi::MetaTensor>> inputs;
-        inputs.reserve(input_var.size());
+        paddle::SmallVector<CompatMetaTensor, phi::kInputSmallVectorSize>
+            inputs;
         for (const auto& in : input_var) {
-          inputs.push_back(
-              std::make_shared<CompatMetaTensor>(in, ctx->IsRuntime()));
+          inputs.emplace_back(
+              std::move(CompatMetaTensor(in, ctx->IsRuntime())));
         }
         infer_meta_context.EmplaceBackInputs(std::move(inputs));
       }
     } else {
-      infer_meta_context.EmplaceBackInput({nullptr});
+      infer_meta_context.EmplaceBackInput(
+          std::move(CompatMetaTensor(ctx->IsRuntime())));
     }
   }
+  VLOG(6) << "BuildInferMetaContext: Done inputs";
 
   auto attr_reader = ctx->Attrs();
   for (size_t i = 0; i < attr_names.size(); ++i) {
-    auto attr_name = attr_names[i];
+    auto& attr_name = attr_names[i];
     if (attr_defs[i].type_index == std::type_index(typeid(phi::IntArray))) {
       // When attr is a vector_tensor or tensor, transform it to IntArray
       if (ctx->HasInputs(attr_name) || ctx->HasInput(attr_name)) {
-        const auto& infershape_inputs = ctx->GetInputVarPtrs(attr_name);
+        auto infershape_inputs = std::move(ctx->GetInputVarPtrs(attr_name));
         if (ctx->IsRuntime()) {
           // If is in runtime, we will get tensor's value for IntArray
           // and push it into attrs
@@ -456,7 +541,7 @@ phi::InferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
                 attr_name));
       }
     } else if (ctx->HasInput(attr_name)) {
-      const auto& infershape_input = ctx->GetInputVarPtrs(attr_name);
+      auto infershape_input = std::move(ctx->GetInputVarPtrs(attr_name));
       if (infershape_input.size() == 1) {
         if (ctx->IsRuntime()) {
           Variable* var = BOOST_GET_CONST(Variable*, infershape_input[0]);
@@ -581,7 +666,7 @@ phi::InferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
       // convert from data
       if (attr_defs[i].type_index == std::type_index(typeid(int32_t))) {
         if (ctx->IsRuntime()) {
-          const auto& infershape_inputs = ctx->GetInputVarPtrs(attr_name);
+          auto infershape_inputs = std::move(ctx->GetInputVarPtrs(attr_name));
           auto var_temp = BOOST_GET_CONST(Variable*, infershape_inputs[i]);
           auto val = experimental::MakePhiScalarFromVar(*var_temp);
           int32_t val_int = val.template to<int32_t>();
@@ -596,36 +681,41 @@ phi::InferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
     }
   }
+  VLOG(6) << "BuildInferMetaContext: Done attrs";
 
   for (auto& out_name : output_names) {
     if (ctx->HasOutputs(out_name, true)) {
-      auto output_var = ctx->GetOutputVarPtrs(out_name);
+      auto output_var = std::move(ctx->GetOutputVarPtrs(out_name));
       if (output_var.size() == 1) {
-        infer_meta_context.EmplaceBackOutput(std::make_shared<CompatMetaTensor>(
-            output_var[0], ctx->IsRuntime()));
+        infer_meta_context.EmplaceBackOutput(
+            std::move(CompatMetaTensor(output_var[0], ctx->IsRuntime())));
       } else {
-        paddle::SmallVector<std::shared_ptr<phi::MetaTensor>> outputs;
-        outputs.reserve(output_var.size());
+        paddle::SmallVector<CompatMetaTensor, phi::kOutputSmallVectorSize>
+            outputs;
         for (const auto& out : output_var) {
           if (ctx->IsRuntime()) {
            if (BOOST_GET_CONST(Variable*, out)) {
              outputs.emplace_back(
-                  std::make_shared<CompatMetaTensor>(out, ctx->IsRuntime()));
+                  std::move(CompatMetaTensor(out, ctx->IsRuntime())));
              continue;
            }
          } else if (BOOST_GET_CONST(VarDesc*, out)) {
            outputs.emplace_back(
-                std::make_shared<CompatMetaTensor>(out, ctx->IsRuntime()));
+                std::move(CompatMetaTensor(out, ctx->IsRuntime())));
            continue;
          }
-          outputs.emplace_back(nullptr);
+          outputs.emplace_back(std::move(CompatMetaTensor(ctx->IsRuntime())));
         }
         infer_meta_context.EmplaceBackOutputs(std::move(outputs));
       }
     } else {
-      infer_meta_context.EmplaceBackOutput({nullptr});
+      infer_meta_context.EmplaceBackOutput(
+          std::move(CompatMetaTensor(ctx->IsRuntime())));
     }
   }
+  VLOG(6) << "BuildInferMetaContext: Done outputs";
 
   return infer_meta_context;
 }
......
@@ -18,38 +18,24 @@ limitations under the License. */
 #include "paddle/fluid/framework/op_info.h"
 #include "paddle/fluid/framework/shape_inference.h"
+#include "paddle/phi/core/infermeta_utils.h"
 #include "paddle/phi/core/meta_tensor.h"
 
-namespace phi {
-class InferMetaContext;
-}  // namespace phi
-
 namespace paddle {
 namespace framework {
 
-phi::InferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
-                                            const std::string& op_type);
-
-#define DECLARE_INFER_SHAPE_FUNCTOR(op_type, functor_name, fn)      \
-  struct functor_name : public paddle::framework::InferShapeBase {  \
-    void operator()(                                                \
-        paddle::framework::InferShapeContext* ctx) const override { \
-      auto infer_meta_context =                                     \
-          paddle::framework::BuildInferMetaContext(ctx, #op_type);  \
-      fn(&infer_meta_context);                                      \
-    }                                                               \
-  }
-
 // TODO(chenweihang): Support TensorArray later
 class CompatMetaTensor : public phi::MetaTensor {
  public:
+  explicit CompatMetaTensor(bool is_runtime)
+      : is_runtime_(is_runtime), initialized_(false) {}
   CompatMetaTensor(InferShapeVarPtr var, bool is_runtime)
       : var_(std::move(var)), is_runtime_(is_runtime) {}
 
-  CompatMetaTensor() = default;
-  CompatMetaTensor(const CompatMetaTensor&) = default;
   CompatMetaTensor(CompatMetaTensor&&) = default;
-  CompatMetaTensor& operator=(const CompatMetaTensor&) = delete;
-  CompatMetaTensor& operator=(CompatMetaTensor&&) = delete;
+  CompatMetaTensor& operator=(CompatMetaTensor&&) = default;
+  CompatMetaTensor(const CompatMetaTensor&) = default;
+  CompatMetaTensor& operator=(const CompatMetaTensor&) = default;
 
   int64_t numel() const override;
@@ -71,6 +57,8 @@ class CompatMetaTensor : public phi::MetaTensor {
   void share_meta(const MetaTensor& meta_tensor) override;
 
+  bool initialized() const override { return initialized_; };
+
  private:
   const LoD& GetRuntimeLoD() const {
     auto* var = BOOST_GET_CONST(Variable*, var_);
@@ -95,7 +83,62 @@ class CompatMetaTensor : public phi::MetaTensor {
   InferShapeVarPtr var_;
   bool is_runtime_;
+  bool initialized_{true};
 };
 
+// Note: In order to avoid using shared_ptr to manage MetaTensor in
+// InferMetaContext, inherit and implement InferMetaContext separately
+// for compatibility with fluid, shared_ptr will cause significant decrease
+// in scheduling performance
+class CompatInferMetaContext : public phi::InferMetaContext {
+ public:
+  CompatInferMetaContext() = default;
+  explicit CompatInferMetaContext(phi::MetaConfig config)
+      : phi::InferMetaContext(config) {}
+
+  void EmplaceBackInput(CompatMetaTensor input);
+  void EmplaceBackOutput(CompatMetaTensor output);
+
+  void EmplaceBackInputs(
+      paddle::SmallVector<CompatMetaTensor, phi::kInputSmallVectorSize> inputs);
+  void EmplaceBackOutputs(
+      paddle::SmallVector<CompatMetaTensor, phi::kOutputSmallVectorSize>
+          outputs);
+
+  const phi::MetaTensor& InputAt(size_t idx) const override;
+  paddle::optional<const phi::MetaTensor&> OptionalInputAt(
+      size_t idx) const override;
+
+  std::vector<const phi::MetaTensor*> InputsBetween(size_t start,
+                                                    size_t end) const override;
+  paddle::optional<const std::vector<const phi::MetaTensor*>>
+  OptionalInputsBetween(size_t start, size_t end) const override;
+
+  phi::MetaTensor* MutableOutputAt(size_t idx) override;
+  std::vector<phi::MetaTensor*> MutableOutputBetween(size_t start,
+                                                     size_t end) override;
+
+  virtual ~CompatInferMetaContext() = default;
+
+ private:
+  paddle::SmallVector<CompatMetaTensor, phi::kInputSmallVectorSize>
+      compat_inputs_;
+  paddle::SmallVector<CompatMetaTensor, phi::kOutputSmallVectorSize>
+      compat_outputs_;
+};
+
+CompatInferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
+                                             const std::string& op_type);
+
+#define DECLARE_INFER_SHAPE_FUNCTOR(op_type, functor_name, fn)      \
+  struct functor_name : public paddle::framework::InferShapeBase {  \
+    void operator()(                                                \
+        paddle::framework::InferShapeContext* ctx) const override { \
+      auto infer_meta_context =                                     \
+          paddle::framework::BuildInferMetaContext(ctx, #op_type);  \
+      fn(&infer_meta_context);                                      \
+    }                                                               \
+  }
+
 }  // namespace framework
 }  // namespace paddle
@@ -328,20 +328,21 @@ bool InterpretercoreInferShapeContext::IsRunMKLDNNKernel() const {
 }
 
 // TODO(paddle-dev): Can this be template?
-std::vector<InferShapeVarPtr> InterpretercoreInferShapeContext::GetInputVarPtrs(
+paddle::SmallVector<InferShapeVarPtr, phi::kInputSmallVectorSize>
+InterpretercoreInferShapeContext::GetInputVarPtrs(
     const std::string& name) const {
   const std::vector<Variable*>& vars = InputVars(name);
-  std::vector<InferShapeVarPtr> res;
+  paddle::SmallVector<InferShapeVarPtr, phi::kInputSmallVectorSize> res;
   res.reserve(vars.size());
   res.insert(res.begin(), vars.begin(), vars.end());
   return res;
 }
 
-std::vector<InferShapeVarPtr>
+paddle::SmallVector<InferShapeVarPtr, phi::kOutputSmallVectorSize>
 InterpretercoreInferShapeContext::GetOutputVarPtrs(
     const std::string& name) const {
   const std::vector<Variable*>& vars = OutputVars(name);
-  std::vector<InferShapeVarPtr> res;
+  paddle::SmallVector<InferShapeVarPtr, phi::kOutputSmallVectorSize> res;
   res.reserve(vars.size());
   res.insert(res.begin(), vars.begin(), vars.end());
   return res;
......
@@ -90,11 +90,11 @@ class InterpretercoreInferShapeContext : public InferShapeContext {
   bool IsRunMKLDNNKernel() const override;
 
   // TODO(paddle-dev): Can this be template?
-  std::vector<InferShapeVarPtr> GetInputVarPtrs(
-      const std::string& name) const override;
+  paddle::SmallVector<InferShapeVarPtr, phi::kInputSmallVectorSize>
+  GetInputVarPtrs(const std::string& name) const override;
 
-  std::vector<InferShapeVarPtr> GetOutputVarPtrs(
-      const std::string& name) const override;
+  paddle::SmallVector<InferShapeVarPtr, phi::kOutputSmallVectorSize>
+  GetOutputVarPtrs(const std::string& name) const override;
 
   DDim GetInputDim(const std::string& name) const override;
......
@@ -202,10 +202,10 @@ class CompileTimeInferShapeContext : public InferShapeContext {
     }
   }
 
-  std::vector<InferShapeVarPtr> GetInputVarPtrs(
-      const std::string &name) const override {
+  paddle::SmallVector<InferShapeVarPtr, phi::kInputSmallVectorSize>
+  GetInputVarPtrs(const std::string &name) const override {
     const std::vector<std::string> arg_names = Inputs(name);
-    std::vector<InferShapeVarPtr> res;
+    paddle::SmallVector<InferShapeVarPtr, phi::kInputSmallVectorSize> res;
     res.reserve(arg_names.size());
     std::transform(arg_names.begin(), arg_names.end(), std::back_inserter(res),
                    [this](const std::string &name) {
@@ -214,10 +214,10 @@ class CompileTimeInferShapeContext : public InferShapeContext {
     return res;
   }
 
-  std::vector<InferShapeVarPtr> GetOutputVarPtrs(
-      const std::string &name) const override {
+  paddle::SmallVector<InferShapeVarPtr, phi::kOutputSmallVectorSize>
+  GetOutputVarPtrs(const std::string &name) const override {
     const std::vector<std::string> arg_names = Outputs(name);
-    std::vector<InferShapeVarPtr> res;
+    paddle::SmallVector<InferShapeVarPtr, phi::kOutputSmallVectorSize> res;
     res.reserve(arg_names.size());
     std::transform(arg_names.begin(), arg_names.end(), std::back_inserter(res),
                    [this](const std::string &name) {
......
@@ -947,19 +947,19 @@ class RuntimeInferShapeContext : public InferShapeContext {
   }
 
   // TODO(paddle-dev): Can this be template?
-  std::vector<InferShapeVarPtr> GetInputVarPtrs(
-      const std::string& name) const override {
+  paddle::SmallVector<InferShapeVarPtr, phi::kInputSmallVectorSize>
+  GetInputVarPtrs(const std::string& name) const override {
    const std::vector<Variable*>& vars = InputVars(name);
-    std::vector<InferShapeVarPtr> res;
+    paddle::SmallVector<InferShapeVarPtr, phi::kInputSmallVectorSize> res;
    res.reserve(vars.size());
    res.insert(res.begin(), vars.begin(), vars.end());
    return res;
   }
 
-  std::vector<InferShapeVarPtr> GetOutputVarPtrs(
-      const std::string& name) const override {
+  paddle::SmallVector<InferShapeVarPtr, phi::kOutputSmallVectorSize>
+  GetOutputVarPtrs(const std::string& name) const override {
    const std::vector<Variable*>& vars = OutputVars(name);
-    std::vector<InferShapeVarPtr> res;
+    paddle::SmallVector<InferShapeVarPtr, phi::kOutputSmallVectorSize> res;
    res.reserve(vars.size());
    res.insert(res.begin(), vars.begin(), vars.end());
    return res;
@@ -1326,8 +1326,8 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
             << ", using_kernel_key:" << *kernel_type_.get();
     auto try_pt_kernel_key =
         TransOpKernelTypeToPhiKernelKey(*kernel_type_.get());
-    if (!phi::KernelFactory::Instance().IsSelectKernelValid(
-            pt_kernel_name, try_pt_kernel_key)) {
+    if (!phi::KernelFactory::Instance().HasKernel(pt_kernel_name,
+                                                  try_pt_kernel_key)) {
       kernel_type_->library_type_ = expected_kernel_key_library_type;
       VLOG(3) << "modify XPU KP kernel in static graph: " << type_
               << " is failed " << *kernel_type_.get();
@@ -2115,10 +2115,12 @@ OpKernelType OperatorWithKernel::GetKernelTypeForVar(
 
 KernelSignature OperatorWithKernel::GetExpectedPhiKernelArgs(
     const ExecutionContext& ctx) const {
+  InitDefaultKernelSignatureMap();
   ExecutionArgumentMappingContext arg_mapping_ctx(ctx);
-  return phi::OpUtilsMap::Instance().GetArgumentMappingFn(Type())(
-      arg_mapping_ctx);
+  if (arg_map_fn_ == nullptr) {
+    arg_map_fn_.reset(new phi::ArgumentMappingFn(
+        phi::OpUtilsMap::Instance().GetArgumentMappingFn(Type())));
+  }
+  return (*arg_map_fn_)(arg_mapping_ctx);
 }
 
 Scope* OperatorWithKernel::PreparePhiData(
......
@@ -701,6 +701,7 @@ class OperatorWithKernel : public OperatorBase {
   mutable bool run_kp_kernel = false;
   mutable std::unique_ptr<phi::KernelSignature> pt_kernel_signature_;
   mutable std::unique_ptr<phi::Kernel> pt_kernel_;
+  mutable std::unique_ptr<phi::ArgumentMappingFn> arg_map_fn_;
 };
 
 extern bool OpSupportGPU(const std::string& op_type);
......
@@ -25,6 +25,7 @@ limitations under the License. */
 #include "paddle/phi/core/compat/convert_utils.h"
 #include "paddle/phi/core/compat/op_utils.h"
 #include "paddle/phi/core/kernel_factory.h"
+#include "paddle/phi/core/type_defs.h"
 
 namespace paddle {
 namespace framework {
@@ -40,9 +41,9 @@ class KernelArgsNameMakerByOpProto : public KernelArgsNameMaker {
   ~KernelArgsNameMakerByOpProto() {}
 
-  const paddle::SmallVector<std::string>& GetInputArgsNames() override;
-  const paddle::SmallVector<std::string>& GetOutputArgsNames() override;
-  const paddle::SmallVector<std::string>& GetAttrsArgsNames() override;
+  const paddle::SmallVector<const char*>& GetInputArgsNames() override;
+  const paddle::SmallVector<const char*>& GetOutputArgsNames() override;
+  const paddle::SmallVector<const char*>& GetAttrsArgsNames() override;
 
   KernelSignature GetKernelSignature();
@@ -52,9 +53,9 @@ class KernelArgsNameMakerByOpProto : public KernelArgsNameMaker {
  private:
   const framework::proto::OpProto* op_proto_;
 
-  paddle::SmallVector<std::string> input_names_;
-  paddle::SmallVector<std::string> output_names_;
-  paddle::SmallVector<std::string> attr_names_;
+  paddle::SmallVector<const char*> input_names_;
+  paddle::SmallVector<const char*> output_names_;
+  paddle::SmallVector<const char*> attr_names_;
 };
 
 OpKernelType TransPhiKernelKeyToOpKernelType(const phi::KernelKey& kernel_key) {
@@ -102,7 +103,7 @@ phi::KernelKey FallBackToCpu(const OpKernelType& expected_kernel_key,
   if (platform::is_xpu_place(expected_kernel_key.place_) ||
       paddle::platform::is_in_xpu_black_list(op.Type())) {
     VLOG(3) << "phi missing XPU kernel: " << op.Type()
-            << ", phipected_kernel_key:" << expected_kernel_key
+            << ", expected_kernel_key:" << expected_kernel_key
             << ", fallbacking to CPU one!";
     return phi::KernelKey(phi::Backend::CPU, kernel_key.layout(),
                           kernel_key.dtype());
@@ -111,7 +112,7 @@ phi::KernelKey FallBackToCpu(const OpKernelType& expected_kernel_key,
 #ifdef PADDLE_WITH_ASCEND_CL
   if (platform::is_npu_place(expected_kernel_key.place_)) {
     VLOG(3) << "phi missing NPU kernel: " << op.Type()
-            << ", phipected_kernel_key:" << expected_kernel_key
+            << ", expected_kernel_key:" << expected_kernel_key
             << ", fallbacking to CPU one!";
     return phi::KernelKey(phi::Backend::CPU, kernel_key.layout(),
                           kernel_key.dtype());
@@ -120,7 +121,7 @@ phi::KernelKey FallBackToCpu(const OpKernelType& expected_kernel_key,
 #ifdef PADDLE_WITH_MLU
   if (platform::is_mlu_place(expected_kernel_key.place_)) {
     VLOG(3) << "phi missing MLU kernel: " << op.Type()
-            << ", phipected_kernel_key:" << expected_kernel_key
+            << ", expected_kernel_key:" << expected_kernel_key
             << ", fallbacking to CPU one!";
     return phi::KernelKey(phi::Backend::CPU, kernel_key.layout(),
                           kernel_key.dtype());
@@ -129,7 +130,7 @@ phi::KernelKey FallBackToCpu(const OpKernelType& expected_kernel_key,
 #ifdef PADDLE_WITH_IPU
   if (platform::is_ipu_place(expected_kernel_key.place_)) {
     VLOG(3) << "phi missing IPU kernel: " << op.Type()
-            << ", phipected_kernel_key:" << expected_kernel_key
+            << ", expected_kernel_key:" << expected_kernel_key
             << ", fallbacking to CPU one!";
     return phi::KernelKey(phi::Backend::CPU, kernel_key.layout(),
                           kernel_key.dtype());
@@ -139,7 +140,7 @@ phi::KernelKey FallBackToCpu(const OpKernelType& expected_kernel_key,
   if (platform::is_custom_place(expected_kernel_key.place_)) {
     VLOG(3) << "phi missing " << expected_kernel_key.place_.GetDeviceType()
             << " kernel: " << op.Type()
-            << ", phipected_kernel_key:" << expected_kernel_key
+            << ", expected_kernel_key:" << expected_kernel_key
             << ", fallbacking to CPU one!";
     return phi::KernelKey(phi::Backend::CPU, kernel_key.layout(),
                           kernel_key.dtype());
@@ -148,45 +149,52 @@ phi::KernelKey FallBackToCpu(const OpKernelType& expected_kernel_key,
   return phi::KernelKey();
 }
 
-const paddle::SmallVector<std::string>&
+const paddle::SmallVector<const char*>&
 KernelArgsNameMakerByOpProto::GetInputArgsNames() {
   for (int i = 0; i < op_proto_->inputs_size(); ++i) {
     auto& in = op_proto_->inputs()[i];
     auto& in_name = in.name();
     if ((in.has_extra() && in.extra()) || (in.has_quant() && in.quant())) {
-      VLOG(6) << "Parse PhiKernel input: skip extra & quant input - "
-              << in_name;
       continue;
     }
     // If contains dispensable input, we should override the
     // OpArgumentMapping method self in phi/ops/compat dir
     if (in.has_dispensable() && in.dispensable()) {
-      VLOG(6) << "Parse PhiKernel input: skip dispensable input - " << in_name;
      continue;
    }
-    VLOG(6) << "Parse PhiKernel input: " << in_name;
-    input_names_.emplace_back(in_name);
+    input_names_.emplace_back(in_name.c_str());
+  }
+  if (VLOG_IS_ON(10)) {
+    std::ostringstream sout;
+    sout << "PhiKernel inputs: ";
+    std::copy(input_names_.begin(), input_names_.end(),
+              std::ostream_iterator<const char*>(sout, ", "));
+    VLOG(10) << sout.str();
   }
   return input_names_;
 }
 
-const paddle::SmallVector<std::string>&
+const paddle::SmallVector<const char*>&
 KernelArgsNameMakerByOpProto::GetOutputArgsNames() {
   for (int i = 0; i < op_proto_->outputs_size(); ++i) {
     auto& out = op_proto_->outputs()[i];
     auto& out_name = out.name();
     if ((out.has_extra() && out.extra()) || (out.has_quant() && out.quant())) {
-      VLOG(6) << "Parse PhiKernel output: skip extra & quant output - "
-              << out_name;
      continue;
    }
-    VLOG(6) << "Parse PhiKernel output: " << out_name;
-    output_names_.emplace_back(out_name);
+    output_names_.emplace_back(out_name.c_str());
+  }
+  if (VLOG_IS_ON(10)) {
+    std::ostringstream sout;
+    sout << "PhiKernel outputs: ";
+    std::copy(output_names_.begin(), output_names_.end(),
+              std::ostream_iterator<const char*>(sout, ", "));
+    VLOG(10) << sout.str();
   }
   return output_names_;
 }
 
-const paddle::SmallVector<std::string>&
+const paddle::SmallVector<const char*>&
 KernelArgsNameMakerByOpProto::GetAttrsArgsNames() {
   for (int i = 0; i < op_proto_->attrs_size(); ++i) {
     auto& attr = op_proto_->attrs()[i];
@@ -195,25 +203,26 @@ KernelArgsNameMakerByOpProto::GetAttrsArgsNames() {
         attr_name == "op_role" || attr_name == "op_role_var" ||
         attr_name == "op_namescope" || attr_name == "op_callstack" ||
         attr_name == "op_device") {
-      VLOG(6) << "Parse PhiKernel attribute: skip needless attr - "
-              << attr_name;
       continue;
     }
     if ((attr.has_extra() && attr.extra()) ||
         (attr.has_quant() && attr.quant())) {
-      VLOG(6) << "Parse PhiKernel attribute: skip extra & quant attr - "
-              << attr_name;
       continue;
     }
-    VLOG(6) << "Parse PhiKernel attribute: " << attr_name;
-    attr_names_.emplace_back(attr_name);
+    attr_names_.emplace_back(attr_name.c_str());
+  }
+  if (VLOG_IS_ON(10)) {
+    std::ostringstream sout;
+    sout << "PhiKernel attributes: ";
+    std::copy(attr_names_.begin(), attr_names_.end(),
+              std::ostream_iterator<const char*>(sout, ", "));
+    VLOG(10) << sout.str();
   }
   return attr_names_;
 }
 
 KernelSignature KernelArgsNameMakerByOpProto::GetKernelSignature() {
-  return KernelSignature(phi::TransToPhiKernelName(op_proto_->type()),
+  return KernelSignature(phi::TransToPhiKernelName(op_proto_->type()).c_str(),
                          GetInputArgsNames(), GetAttrsArgsNames(),
                          GetOutputArgsNames());
 }
@@ -228,7 +237,7 @@ void InitDefaultKernelSignatureMap() {
     if (phi::KernelFactory::Instance().HasCompatiblePhiKernel(op_type) &&
         op_proto) {
       paddle::framework::KernelArgsNameMakerByOpProto maker(op_proto);
-      VLOG(10) << "Register kernel signature for " << op_type;
+      VLOG(10) << "Register `" << op_type << "` kernel signature:";
       phi::DefaultKernelSignatureMap::Instance().Insert(
           op_type, std::move(maker.GetKernelSignature()));
     }
......
@@ -55,9 +55,9 @@ phi::KernelKey FallBackToCpu(const OpKernelType& expected_kernel_key,
 class KernelArgsNameMaker {
  public:
   virtual ~KernelArgsNameMaker() {}
-  virtual const paddle::SmallVector<std::string>& GetInputArgsNames() = 0;
-  virtual const paddle::SmallVector<std::string>& GetOutputArgsNames() = 0;
-  virtual const paddle::SmallVector<std::string>& GetAttrsArgsNames() = 0;
+  virtual const paddle::SmallVector<const char*>& GetInputArgsNames() = 0;
+  virtual const paddle::SmallVector<const char*>& GetOutputArgsNames() = 0;
+  virtual const paddle::SmallVector<const char*>& GetAttrsArgsNames() = 0;
 };
 
 void InitDefaultKernelSignatureMap();
......
@@ -21,6 +21,8 @@ limitations under the License. */
 #include "paddle/fluid/framework/var_desc.h"
 #include "paddle/fluid/framework/variable.h"
 #include "paddle/phi/core/ddim.h"
+#include "paddle/phi/core/type_defs.h"
+#include "paddle/utils/small_vector.h"
 
 namespace paddle {
 namespace framework {
@@ -106,10 +108,10 @@ class InferShapeContext {
   virtual bool IsRunMKLDNNKernel() const = 0;
 
-  virtual std::vector<InferShapeVarPtr> GetInputVarPtrs(
-      const std::string &name) const = 0;
-  virtual std::vector<InferShapeVarPtr> GetOutputVarPtrs(
-      const std::string &name) const = 0;
+  virtual paddle::SmallVector<InferShapeVarPtr, phi::kInputSmallVectorSize>
+  GetInputVarPtrs(const std::string &name) const = 0;
+  virtual paddle::SmallVector<InferShapeVarPtr, phi::kOutputSmallVectorSize>
+  GetOutputVarPtrs(const std::string &name) const = 0;
 
  protected:
   virtual std::vector<DDim> GetRepeatedDims(const std::string &name) const = 0;
......
@@ -235,9 +235,10 @@ class DygraphInferShapeContext : public framework::InferShapeContext {
             (op_kernel_type_->data_layout_ == framework::DataLayout::kMKLDNN));
   }
 
-  std::vector<framework::InferShapeVarPtr> GetInputVarPtrs(
-      const std::string& name) const override {
-    std::vector<framework::InferShapeVarPtr> res;
+  paddle::SmallVector<framework::InferShapeVarPtr, phi::kInputSmallVectorSize>
+  GetInputVarPtrs(const std::string& name) const override {
+    paddle::SmallVector<framework::InferShapeVarPtr, phi::kInputSmallVectorSize>
+        res;
     auto it = var_map_in_->find(name);
     PADDLE_ENFORCE_NE(
         it, var_map_in_->end(),
@@ -248,9 +249,11 @@ class DygraphInferShapeContext : public framework::InferShapeContext {
     return res;
   }
 
-  std::vector<framework::InferShapeVarPtr> GetOutputVarPtrs(
-      const std::string& name) const override {
-    std::vector<framework::InferShapeVarPtr> res;
+  paddle::SmallVector<framework::InferShapeVarPtr, phi::kOutputSmallVectorSize>
+  GetOutputVarPtrs(const std::string& name) const override {
+    paddle::SmallVector<framework::InferShapeVarPtr,
+                        phi::kOutputSmallVectorSize>
+        res;
     auto it = var_map_out_->find(name);
     PADDLE_ENFORCE_NE(
         it, var_map_out_->end(),
......
@@ -36,6 +36,8 @@ DECLARE_bool(run_kp_kernel);
 namespace paddle {
 namespace imperative {
 
+static const phi::Kernel empty_kernel;
+
 const std::shared_ptr<VariableWrapper>& GetVariableWrapper(
     const std::shared_ptr<paddle::imperative::VarBase>& var) {
   return var->SharedVar();
@@ -108,12 +110,13 @@ PreparedOp::PreparedOp(const framework::OperatorBase& op,
       ctx_(ctx),
       kernel_type_(kernel_type),
       func_(func),
-      dev_ctx_(dev_ctx) {}
+      dev_ctx_(dev_ctx),
+      pt_kernel_(empty_kernel) {}
 
 PreparedOp::PreparedOp(const framework::OperatorBase& op,
                        const framework::RuntimeContext& ctx,
                        const framework::OpKernelType& kernel_type,
-                       const framework::KernelSignature& kernel_signature,
+                       framework::KernelSignature&& kernel_signature,
                        const phi::Kernel& pt_kernel,
                        platform::DeviceContext* dev_ctx)
     : op_(op),
@@ -122,7 +125,7 @@ PreparedOp::PreparedOp(const framework::OperatorBase& op,
       func_(nullptr),
       dev_ctx_(dev_ctx),
       run_phi_kernel_(true),
-      pt_kernel_signature_(kernel_signature),
+      pt_kernel_signature_(std::move(kernel_signature)),
       pt_kernel_(pt_kernel) {}
 
 template <typename VarType>
@@ -170,7 +173,8 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins,
 #endif
 
   if (phi::KernelFactory::Instance().HasCompatiblePhiKernel(op.Type())) {
-    pt_kernel_signature = op.GetExpectedPhiKernelArgs(dygraph_exe_ctx);
+    pt_kernel_signature =
+        std::move(op.GetExpectedPhiKernelArgs(dygraph_exe_ctx));
     VLOG(6) << pt_kernel_signature;
 
     pt_kernel_name = pt_kernel_signature.name;
@@ -200,8 +204,8 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins,
               << ", using_kernel_key:" << expected_kernel_key;
       phi::KernelKey try_pt_kernel_key =
           TransOpKernelTypeToPhiKernelKey(expected_kernel_key);
-      if (!phi::KernelFactory::Instance().IsSelectKernelValid(
-              pt_kernel_name, try_pt_kernel_key)) {
+      if (!phi::KernelFactory::Instance().HasKernel(pt_kernel_name,
+                                                    try_pt_kernel_key)) {
         expected_kernel_key.library_type_ = expected_kernel_key_library_type;
         VLOG(3) << "modify XPU KP kernel: " << op.Type() << " is failed "
                 << expected_kernel_key;
@@ -211,8 +215,8 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins,
 #endif
 
     pt_kernel_key = TransOpKernelTypeToPhiKernelKey(expected_kernel_key);
-    auto pt_kernel = phi::KernelFactory::Instance().SelectKernel(pt_kernel_name,
-                                                                 pt_kernel_key);
+    auto& pt_kernel = phi::KernelFactory::Instance().SelectKernel(
+        pt_kernel_name, pt_kernel_key);
 
     if (pt_kernel.IsValid()
 #if defined(PADDLE_WITH_XPU) && !defined(PADDLE_WITH_XPU_KP)
@@ -227,9 +231,8 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins,
       dev_ctx = pool.Get(expected_kernel_key.place_);
     }
 
-    // TODO(chenweihang): using CPUKernel when miss device kernel case
-    return PreparedOp(op, ctx, expected_kernel_key, pt_kernel_signature,
-                      pt_kernel, dev_ctx);
+    return PreparedOp(op, ctx, expected_kernel_key,
+                      std::move(pt_kernel_signature), pt_kernel, dev_ctx);
   } else {
     VLOG(6) << "Dynamic mode ChoosePhiKernel - kernel `" << pt_kernel_name
             << "` not found.";
@@ -270,15 +273,16 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins,
     if (phi::KernelFactory::Instance().HasCompatiblePhiKernel(op.Type())) {
       auto pt_cpu_kernel_key =
           FallBackToCpu(expected_kernel_key, pt_kernel_key, op);
-      auto pt_cpu_kernel = phi::KernelFactory::Instance().SelectKernel(
+      auto& pt_cpu_kernel = phi::KernelFactory::Instance().SelectKernel(
           pt_kernel_name, pt_cpu_kernel_key);
       if (pt_cpu_kernel.IsValid()) {
         VLOG(6) << "Dynamic mode PrepareImpl - kernel name: " << pt_kernel_name
                 << " | kernel key: " << pt_cpu_kernel_key
                 << " | kernel: " << pt_cpu_kernel;
         auto* cpu_ctx = pool.Get(paddle::platform::CPUPlace());
-        return PreparedOp(op, ctx, expected_kernel_key, pt_kernel_signature,
-                          pt_cpu_kernel, cpu_ctx);
+        return PreparedOp(op, ctx, expected_kernel_key,
+                          std::move(pt_kernel_signature), pt_cpu_kernel,
+                          cpu_ctx);
       }
     }
   }
@@ -505,7 +509,6 @@ static void PreparedOpRunPtImpl(
 #endif
   }
 
-  // TODO(chenweihang): add debug flags later
   if (framework::IsComplexType(kernel_type.data_type_)) {
     HandleComplexGradToRealGrad<VarType>(outs);
   }
......
@@ -154,7 +154,7 @@ class PreparedOp {
   PreparedOp(const framework::OperatorBase& op,
              const framework::RuntimeContext& ctx,
              const framework::OpKernelType& kernel_type,
-             const framework::KernelSignature& kernel_signature,
+             framework::KernelSignature&& kernel_signature,
              const phi::Kernel& pt_kernel, platform::DeviceContext* dev_ctx);
 
   static PreparedOp Prepare(const NameVarMap<VarBase>& ins,
@@ -206,7 +206,7 @@ class PreparedOp {
   bool run_phi_kernel_{false};
   bool run_kp_kernel_{false};
   framework::KernelSignature pt_kernel_signature_;
-  phi::Kernel pt_kernel_;
+  const phi::Kernel& pt_kernel_;
 };
 
 const inline framework::Attribute& GetAttr(
@@ -289,7 +289,7 @@ void BuildDygraphPhiKernelContext(
       }
     }
 
-    auto ins_vector = it->second;
+    auto& ins_vector = it->second;
     size_t end_idx = start_idx + ins_vector.size();
 
     for (size_t offset = 0; offset < ins_vector.size(); ++offset) {
@@ -587,7 +587,7 @@ void PreparePhiData(const phi::Kernel& pt_kernel,
     auto& ins_vector = ins.at(input_names[i]);
 
     for (size_t offset = 0; offset < ins_vector.size(); ++offset) {
-      auto var = ins_vector[offset];
+      auto& var = ins_vector[offset];
       const auto* tensor_in = GetTensorFromVar(var->Var());
       if (tensor_in && tensor_in->IsInitialized()) {
         if (in_def.backend == phi::Backend::ALL_BACKEND) {
......
@@ -226,6 +226,7 @@ bool AnalysisPredictor::PrepareScope(
     status_is_cloned_ = true;
   } else {
     paddle::framework::InitDevices();
+    paddle::framework::InitDefaultKernelSignatureMap();
     // TODO(wilber): we need to release memory occupied by weights.
     scope_.reset(new paddle::framework::Scope());
     status_is_cloned_ = false;
......
@@ -92,6 +92,7 @@ bool NativePaddlePredictor::Init(
                           "The sub_scope should not be nullptr."));
   } else {
     paddle::framework::InitDevices();
+    paddle::framework::InitDefaultKernelSignatureMap();
     scope_.reset(new paddle::framework::Scope());
   }
......
@@ -517,10 +517,8 @@ class WhileGradOpShapeInference : public framework::InferShapeBase {
     ctx->HasInputs(kOutputs);
     ctx->HasInputs(framework::GradVarName(kOutputs));
     auto pg_ig_names = ctx->Outputs(kXGRAD);
-    std::vector<framework::InferShapeVarPtr> in_var_ptrs =
-        ctx->GetInputVarPtrs(kX);
-    std::vector<framework::InferShapeVarPtr> out_var_ptrs =
-        ctx->GetOutputVarPtrs(kXGRAD);
+    auto in_var_ptrs = ctx->GetInputVarPtrs(kX);
+    auto out_var_ptrs = ctx->GetOutputVarPtrs(kXGRAD);
     PADDLE_ENFORCE_EQ(in_var_ptrs.size(), out_var_ptrs.size(),
                       platform::errors::InvalidArgument(
                           "The size of Inputs(X) must be the same as "
......
@@ -63,10 +63,8 @@ class CollectFpnProposalsOp : public framework::OperatorWithKernel {
       context->ShareLoD("MultiLevelRois", "FpnRois");
     }
     if (context->IsRuntime() && !context->HasInputs("MultiLevelRoIsNum")) {
-      std::vector<framework::InferShapeVarPtr> roi_inputs =
-          context->GetInputVarPtrs("MultiLevelRois");
-      std::vector<framework::InferShapeVarPtr> score_inputs =
-          context->GetInputVarPtrs("MultiLevelScores");
+      auto roi_inputs = context->GetInputVarPtrs("MultiLevelRois");
+      auto score_inputs = context->GetInputVarPtrs("MultiLevelScores");
       for (size_t i = 0; i < roi_inputs.size(); ++i) {
         framework::Variable *roi_var =
             BOOST_GET(framework::Variable *, roi_inputs[i]);
......
@@ -60,6 +60,7 @@ limitations under the License. */
 #include "paddle/fluid/pybind/uva_utils.h"
 #include "paddle/phi/core/compat/arg_map_context.h"
 #include "paddle/phi/core/compat/type_defs.h"
+#include "paddle/phi/core/type_defs.h"
 
 namespace paddle {
 namespace pybind {
@@ -2027,26 +2028,35 @@ void BindImperative(py::module *m_ptr) {
                 *(imperative::AmpOperators::Instance().GetMutableAllowOps()),
                 *(imperative::AmpOperators::Instance().GetMutableBlockOps()));
           })
-      .def("_get_kernel_signature",
-           [](imperative::Tracer &self, const std::string &type,
-              const PyNameVarBaseMap &ins, const PyNameVarBaseMap &outs,
-              framework::AttributeMap attrs) {
-             // TODO(xiongkun): move this function outside of tracer.
-             auto ins_map = ConvertToNameTensorMap(ins);
-             auto outs_map = ConvertToNameTensorMap(outs);
-             {
-               auto to_vector = [](paddle::SmallVector<std::string> &vec) {
-                 return std::vector<std::string>(vec.begin(), vec.end());
-               };
-               auto ret = self.GetExpectedKernelSignature(type, ins_map,
-                                                          outs_map, attrs);
-               auto kernelsig_ins = to_vector(std::get<0>(ret.args));
-               auto kernelsig_attrs = to_vector(std::get<1>(ret.args));
-               auto kernelsig_outs = to_vector(std::get<2>(ret.args));
-               return std::make_tuple(kernelsig_ins, kernelsig_attrs,
-                                      kernelsig_outs);
-             }
-           })
+      .def(
+          "_get_kernel_signature",
+          [](imperative::Tracer &self, const std::string &type,
+             const PyNameVarBaseMap &ins, const PyNameVarBaseMap &outs,
+             framework::AttributeMap attrs) {
+            // TODO(xiongkun): move this function outside of tracer.
+            auto ins_map = ConvertToNameTensorMap(ins);
+            auto outs_map = ConvertToNameTensorMap(outs);
+            {
+              auto input_to_vector =
+                  [](paddle::SmallVector<const char *> &vec) {
+                    return std::vector<std::string>(vec.begin(), vec.end());
+                  };
+              auto output_to_vector =
+                  [](paddle::SmallVector<const char *> &vec) {
+                    return std::vector<std::string>(vec.begin(), vec.end());
+                  };
+              auto attr_to_vector = [](paddle::SmallVector<const char *> &vec) {
+                return std::vector<std::string>(vec.begin(), vec.end());
+              };
+              auto ret = self.GetExpectedKernelSignature(type, ins_map,
+                                                         outs_map, attrs);
+              auto kernelsig_ins = input_to_vector(std::get<0>(ret.args));
+              auto kernelsig_attrs = attr_to_vector(std::get<1>(ret.args));
+              auto kernelsig_outs = output_to_vector(std::get<2>(ret.args));
+              return std::make_tuple(kernelsig_ins, kernelsig_attrs,
+                                     kernelsig_outs);
+            }
+          })
       .def("trace",
            [](imperative::Tracer &self, const std::string &type,
              const PyNameVarBaseMap &ins, const PyNameVarBaseMap &outs,
......
@@ -2907,6 +2907,8 @@ All parameter, weight, gradient are variables in Paddle.
                       framework::LoadOpMetaInfoAndRegisterOp(dso_name));
   });
   m.def("init_devices", []() { framework::InitDevices(); });
+  m.def("init_default_kernel_signatures",
+        []() { framework::InitDefaultKernelSignatureMap(); });
   m.def("is_compiled_with_cuda", IsCompiledWithCUDA);
   m.def("is_compiled_with_ascend", IsCompiledWithAscend);
   m.def("is_compiled_with_rocm", IsCompiledWithROCM);
......
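The binding added in the hunk above is callable from Python once the core module is loaded; a hedged usage sketch (the import path follows the usual fluid core-module convention, and where Paddle itself invokes the call is not shown in this commit):

    from paddle.fluid import core

    core.init_devices()
    # Pre-populates the default kernel signature map so later lookups during
    # op scheduling do not pay the construction cost on first use.
    core.init_default_kernel_signatures()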
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
#include "paddle/infrt/dialect/phi/pass/kernel_op_desc.h" #include "paddle/infrt/dialect/phi/pass/kernel_op_desc.h"
#include <glog/logging.h> #include <glog/logging.h>
#include "paddle/infrt/dialect/phi/data_type.h" #include "paddle/infrt/dialect/phi/data_type.h"
#include "paddle/phi/core/type_defs.h"
#include "paddle/phi/kernels/declarations.h" #include "paddle/phi/kernels/declarations.h"
namespace infrt { namespace infrt {
...@@ -92,10 +93,10 @@ std::vector<PhiKernelDesc> GetCandidateKernels( ...@@ -92,10 +93,10 @@ std::vector<PhiKernelDesc> GetCandidateKernels(
phi_kernel_desc.input_types.clear(); phi_kernel_desc.input_types.clear();
phi_kernel_desc.output_types.clear(); phi_kernel_desc.output_types.clear();
phi::KernelArgsDef args_def = kernel_key_map.at(kernel_key).args_def(); phi::KernelArgsDef args_def = kernel_key_map.at(kernel_key).args_def();
-    const paddle::SmallVector<phi::TensorArgDef>& input_arg =
-        args_def.input_defs();
-    const paddle::SmallVector<phi::TensorArgDef>& output_arg =
-        args_def.output_defs();
+    const paddle::SmallVector<phi::TensorArgDef, phi::kInputSmallVectorSize>&
+        input_arg = args_def.input_defs();
+    const paddle::SmallVector<phi::TensorArgDef, phi::kOutputSmallVectorSize>&
+        output_arg = args_def.output_defs();
    for (auto tensor_arg : input_arg) {
      phi_kernel_desc.input_types.emplace_back(ConvertPlaceFromPhi(tensor_arg));
    }
......
...@@ -91,6 +91,7 @@ using ValueVariantType =
    std::vector<::phi::DenseTensor*>,
    paddle::experimental::ScalarBase<::phi::DenseTensor>,
    paddle::experimental::IntArrayBase<::phi::DenseTensor>,
+    std::vector<const ::phi::MetaTensor*>,
    std::vector<::phi::MetaTensor*>,
    ::phi::MetaConfig,
    paddle::experimental::Backend,
......
...@@ -271,10 +271,10 @@ std::vector<Tensor> split_impl(const Tensor& x,
  // Calculate the number of out tensors
  size_t out_number;
-  if (num_or_sections.GetData().size() == 1) {
+  if (num_or_sections.size() == 1) {
    out_number = num_or_sections.GetData()[0];
  } else {
-    out_number = num_or_sections.GetData().size();
+    out_number = num_or_sections.size();
  }
  std::vector<Tensor> out;
...@@ -449,54 +449,6 @@ std::tuple<Tensor, Tensor, Tensor> momentum_impl(
  return api_output;
}
std::vector<Tensor> unbind_impl(const Tensor& input, int axis) {
auto kernel_key_set = ParseKernelKeyByInputArgs(input);
auto kernel_key = kernel_key_set.GetHighestPriorityKernelKey();
Backend kernel_backend = kernel_key.backend();
DataLayout kernel_layout = kernel_key.layout();
DataType kernel_data_type = kernel_key.dtype();
auto kernel = phi::KernelFactory::Instance().SelectKernelOrThrowError(
"unbind", {kernel_backend, kernel_layout, kernel_data_type});
VLOG(6) << "unbind API kernel key: [" << kernel_backend << ", "
<< kernel_layout << ", " << kernel_data_type << "]";
VLOG(6) << "unbind API kernel: " << kernel;
auto* dev_ctx = GetDeviceContextByBackend(kernel_backend);
auto dense_input = PrepareData(input, kernel.InputAt(0), {});
// Calculate the number of out tensors
auto input_shape = input.dims();
if (axis < 0) {
axis = input_shape.size() + axis;
}
auto out_num = input_shape[axis];
std::vector<Tensor> out;
auto dense_outs = SetKernelOutput(out_num, kernel_backend, &out);
std::vector<phi::MetaTensor> meta_outs;
meta_outs.reserve(out_num);
std::vector<phi::MetaTensor*> meta_out_ptrs;
meta_out_ptrs.reserve(out_num);
for (int64_t i = 0; i < out_num; ++i) {
meta_outs.push_back(dense_outs[i]);
meta_out_ptrs.push_back(&meta_outs.back());
}
phi::UnbindInferMeta(MakeMetaTensor(*dense_input), axis, meta_out_ptrs);
using kernel_signature = void (*)(const phi::DeviceContext&,
const phi::DenseTensor&,
int,
std::vector<phi::DenseTensor*>&);
auto* kernel_fn = kernel.GetVariadicKernelFn<kernel_signature>();
(*kernel_fn)(*dev_ctx, *dense_input, axis, dense_outs);
return out;
}
////////////////// Backward(grad) api impls //////////////////////
// TODO(chenweihang): the original sum grad op can support higher-level
...@@ -674,71 +626,6 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor, Tensor> batch_norm_impl(
  return api_output;
}
std::vector<Tensor> concat_grad_impl(const std::vector<Tensor>& x,
const Tensor& out_grad,
const Scalar& axis) {
auto kernel_key_set = ParseKernelKeyByInputArgs(out_grad);
auto kernel_key = kernel_key_set.GetHighestPriorityKernelKey();
Backend kernel_backend = kernel_key.backend();
DataLayout kernel_layout = kernel_key.layout();
DataType kernel_data_type = kernel_key.dtype();
auto kernel = phi::KernelFactory::Instance().SelectKernelOrThrowError(
"concat_grad", {kernel_backend, kernel_layout, kernel_data_type});
VLOG(6) << "concat_grad API kernel key: [" << kernel_backend << ", "
<< kernel_layout << ", " << kernel_data_type << "]";
VLOG(6) << "concat_grad API kernel: " << kernel;
auto* dev_ctx = GetDeviceContextByBackend(kernel_backend);
// std::unique_ptr<std::vector<phi::DenseTensor>>
auto dense_x = PrepareData(x, kernel.InputAt(0), {});
auto dense_out_grad = PrepareData(out_grad, kernel.InputAt(1), {});
// Calculate the number of out tensors
size_t out_number = x.size();
std::vector<Tensor> x_grad;
auto dense_x_grad = SetKernelOutput(out_number, kernel_backend, &x_grad);
std::vector<phi::MetaTensor> meta_x;
meta_x.reserve(x.size());
std::vector<phi::MetaTensor*> meta_x_ptrs;
meta_x_ptrs.reserve(x.size());
for (const auto& t : *dense_x) {
meta_x.push_back(t);
meta_x_ptrs.push_back(&meta_x.back());
}
std::vector<phi::MetaTensor> meta_x_grad;
meta_x_grad.reserve(x.size());
std::vector<phi::MetaTensor*> meta_x_grad_ptrs;
meta_x_grad_ptrs.reserve(x.size());
for (size_t i = 0; i < out_number; ++i) {
meta_x_grad.push_back(*dense_x_grad[i]);
meta_x_grad_ptrs.push_back(&meta_x_grad.back());
}
phi::UnchangedMultiInferMeta(meta_x_ptrs, meta_x_grad_ptrs);
std::vector<const phi::DenseTensor*> dense_x_ptr;
dense_x_ptr.reserve(x.size());
for (const auto& t : *dense_x) {
dense_x_ptr.push_back(&t);
}
using kernel_signature = void (*)(const platform::DeviceContext&,
const std::vector<const phi::DenseTensor*>&,
const phi::DenseTensor&,
const phi::Scalar&,
std::vector<phi::DenseTensor*>);
auto* kernel_fn = kernel.GetVariadicKernelFn<kernel_signature>();
(*kernel_fn)(
*dev_ctx, dense_x_ptr, *dense_out_grad, phi::Scalar(axis), dense_x_grad);
return x_grad;
}
Tensor imag_grad_impl(const Tensor& out_grad) {
  phi::KernelKey kernel_key{ParseBackend(out_grad),
                            out_grad.layout(),
...@@ -795,328 +682,5 @@ Tensor real_grad_impl(const Tensor& out_grad) {
  return out;
}
std::vector<Tensor> stack_grad_impl(const std::vector<Tensor>& x,
const Tensor& out_grad,
int axis) {
auto kernel_key_set = ParseKernelKeyByInputArgs(out_grad);
auto kernel_key = kernel_key_set.GetHighestPriorityKernelKey();
Backend kernel_backend = kernel_key.backend();
DataLayout kernel_layout = kernel_key.layout();
DataType kernel_data_type = kernel_key.dtype();
auto kernel = phi::KernelFactory::Instance().SelectKernelOrThrowError(
"stack_grad", {kernel_backend, kernel_layout, kernel_data_type});
VLOG(6) << "stack_grad API kernel key: [" << kernel_backend << ", "
<< kernel_layout << ", " << kernel_data_type << "]";
VLOG(6) << "stack_grad API kernel: " << kernel;
auto* dev_ctx = GetDeviceContextByBackend(kernel_backend);
auto dense_out_grad = PrepareData(out_grad, kernel.InputAt(0), {});
size_t out_number = x.size();
std::vector<Tensor> x_grad;
auto dense_x_grad = SetKernelOutput(out_number, kernel_backend, &x_grad);
std::vector<phi::MetaTensor> meta_x_grad;
meta_x_grad.reserve(out_number);
std::vector<phi::MetaTensor*> meta_x_grad_ptrs;
meta_x_grad_ptrs.reserve(out_number);
for (size_t i = 0; i < out_number; ++i) {
meta_x_grad.push_back(dense_x_grad[i]);
meta_x_grad_ptrs.push_back(&meta_x_grad.back());
}
phi::StackGradInferMeta(
MakeMetaTensor(*dense_out_grad), axis, meta_x_grad_ptrs);
using kernel_signature = void (*)(const platform::DeviceContext&,
const phi::DenseTensor&,
int axis,
std::vector<phi::DenseTensor*>);
auto* kernel_fn = kernel.GetVariadicKernelFn<kernel_signature>();
(*kernel_fn)(*dev_ctx, *dense_out_grad, axis, dense_x_grad);
return x_grad;
}
std::vector<Tensor> meshgrid_impl(const std::vector<Tensor>& inputs) {
Backend kernel_backend = Backend::UNDEFINED;
DataLayout kernel_layout = DataLayout::UNDEFINED;
DataType kernel_data_type = DataType::UNDEFINED;
if (kernel_backend == Backend::UNDEFINED ||
kernel_layout == DataLayout::UNDEFINED ||
kernel_data_type == DataType::UNDEFINED) {
auto kernel_key_set = ParseKernelKeyByInputArgs(inputs);
auto kernel_key = kernel_key_set.GetHighestPriorityKernelKey();
if (kernel_backend == Backend::UNDEFINED) {
kernel_backend = kernel_key.backend();
}
if (kernel_layout == DataLayout::UNDEFINED) {
kernel_layout = kernel_key.layout();
}
if (kernel_data_type == DataType::UNDEFINED) {
kernel_data_type = kernel_key.dtype();
}
}
const auto& kernel = phi::KernelFactory::Instance().SelectKernelOrThrowError(
"meshgrid", {kernel_backend, kernel_layout, kernel_data_type});
VLOG(6) << "meshgrid API kernel key: [" << kernel_backend << ", "
<< kernel_layout << ", " << kernel_data_type << "]";
VLOG(6) << "meshgrid API kernel: " << kernel;
auto* dev_ctx = GetDeviceContextByBackend(kernel_backend);
auto input_inputs_vec = PrepareData(inputs, kernel.InputAt(0), {});
std::vector<const phi::DenseTensor*> input_inputs(input_inputs_vec->size());
for (size_t i = 0; i < input_inputs.size(); ++i) {
input_inputs[i] = &input_inputs_vec->at(i);
}
auto x_meta_vec = MakeMetaTensor(input_inputs);
std::vector<phi::MetaTensor*> inputs_metas(x_meta_vec.size());
for (size_t i = 0; i < x_meta_vec.size(); ++i) {
inputs_metas[i] = &x_meta_vec[i];
}
// Calculate the number of out tensors
size_t out_number = inputs.size();
std::vector<Tensor> out;
auto dense_outs = SetKernelOutput(out_number, kernel_backend, &out);
std::vector<phi::MetaTensor> meta_outs;
meta_outs.reserve(out_number);
std::vector<phi::MetaTensor*> meta_out_ptrs;
meta_out_ptrs.reserve(out_number);
for (size_t i = 0; i < out_number; ++i) {
meta_outs.push_back(dense_outs[i]);
meta_out_ptrs.push_back(&meta_outs.back());
}
phi::MeshgridInferMeta(inputs_metas, meta_out_ptrs);
using kernel_signature = void (*)(const platform::DeviceContext&,
const std::vector<const phi::DenseTensor*>&,
std::vector<phi::DenseTensor*>&);
auto* kernel_fn = kernel.GetVariadicKernelFn<kernel_signature>();
(*kernel_fn)(*dev_ctx, input_inputs, dense_outs);
return out;
}
std::vector<Tensor> meshgrid_grad_impl(
const std::vector<Tensor>& inputs,
const std::vector<Tensor>& outputs_grad) {
Backend kernel_backend = Backend::UNDEFINED;
DataLayout kernel_layout = DataLayout::UNDEFINED;
DataType kernel_data_type = DataType::UNDEFINED;
if (kernel_backend == Backend::UNDEFINED ||
kernel_layout == DataLayout::UNDEFINED ||
kernel_data_type == DataType::UNDEFINED) {
auto kernel_key_set = ParseKernelKeyByInputArgs(inputs, outputs_grad);
auto kernel_key = kernel_key_set.GetHighestPriorityKernelKey();
if (kernel_backend == Backend::UNDEFINED) {
kernel_backend = kernel_key.backend();
}
if (kernel_layout == DataLayout::UNDEFINED) {
kernel_layout = kernel_key.layout();
}
if (kernel_data_type == DataType::UNDEFINED) {
kernel_data_type = kernel_key.dtype();
}
}
const auto& kernel = phi::KernelFactory::Instance().SelectKernelOrThrowError(
"meshgrid_grad", {kernel_backend, kernel_layout, kernel_data_type});
VLOG(6) << "meshgrid_grad API kernel key: [" << kernel_backend << ", "
<< kernel_layout << ", " << kernel_data_type << "]";
VLOG(6) << "meshgrid_grad API kernel: " << kernel;
auto* dev_ctx = GetDeviceContextByBackend(kernel_backend);
auto input_inputs_vec = PrepareData(inputs, kernel.InputAt(0), {});
std::vector<const phi::DenseTensor*> input_inputs(input_inputs_vec->size());
for (size_t i = 0; i < input_inputs.size(); ++i) {
input_inputs[i] = &input_inputs_vec->at(i);
}
auto input_outputs_grad_vec =
PrepareData(outputs_grad, kernel.InputAt(1), {});
std::vector<const phi::DenseTensor*> input_outputs_grad(
input_outputs_grad_vec->size());
for (size_t i = 0; i < input_outputs_grad.size(); ++i) {
input_outputs_grad[i] = &input_outputs_grad_vec->at(i);
}
size_t out_number = inputs.size();
std::vector<Tensor> api_output;
auto kernel_out = SetKernelOutput(out_number, kernel_backend, &api_output);
auto inputs_meta_vec = MakeMetaTensor(input_inputs);
std::vector<phi::MetaTensor*> inputs_metas(inputs_meta_vec.size());
for (size_t i = 0; i < inputs_meta_vec.size(); ++i) {
inputs_metas[i] = &inputs_meta_vec[i];
}
auto outputs_grad_meta_vec = MakeMetaTensor(input_outputs_grad);
std::vector<phi::MetaTensor*> outputs_grad_metas(
outputs_grad_meta_vec.size());
for (size_t i = 0; i < outputs_grad_meta_vec.size(); ++i) {
outputs_grad_metas[i] = &outputs_grad_meta_vec[i];
}
std::vector<phi::MetaTensor> meta_outs;
meta_outs.reserve(out_number);
std::vector<phi::MetaTensor*> meta_out_ptrs;
meta_out_ptrs.reserve(out_number);
for (size_t i = 0; i < out_number; ++i) {
meta_outs.push_back(kernel_out[i]);
meta_out_ptrs.push_back(&meta_outs.back());
}
phi::MeshgridGradInferMeta(inputs_metas, outputs_grad_metas, meta_out_ptrs);
using kernel_signature = void (*)(const platform::DeviceContext&,
const std::vector<const phi::DenseTensor*>&,
const std::vector<const phi::DenseTensor*>&,
std::vector<phi::DenseTensor*>&);
auto* kernel_fn = kernel.GetVariadicKernelFn<kernel_signature>();
(*kernel_fn)(*dev_ctx, input_inputs, input_outputs_grad, kernel_out);
return api_output;
}
std::vector<Tensor> multi_dot_grad_impl(const std::vector<Tensor>& x,
const Tensor& out_grad) {
Backend kernel_backend = Backend::UNDEFINED;
DataLayout kernel_layout = DataLayout::UNDEFINED;
DataType kernel_data_type = DataType::UNDEFINED;
if (kernel_backend == Backend::UNDEFINED ||
kernel_layout == DataLayout::UNDEFINED ||
kernel_data_type == DataType::UNDEFINED) {
auto kernel_key_set = ParseKernelKeyByInputArgs(x, out_grad);
auto kernel_key = kernel_key_set.GetHighestPriorityKernelKey();
if (kernel_backend == Backend::UNDEFINED) {
kernel_backend = kernel_key.backend();
}
if (kernel_layout == DataLayout::UNDEFINED) {
kernel_layout = kernel_key.layout();
}
if (kernel_data_type == DataType::UNDEFINED) {
kernel_data_type = kernel_key.dtype();
}
}
VLOG(6) << "multi_dot_grad API kernel key: [" << kernel_backend << ", "
<< kernel_layout << ", " << kernel_data_type << "]";
const auto& kernel = phi::KernelFactory::Instance().SelectKernelOrThrowError(
"multi_dot_grad", {kernel_backend, kernel_layout, kernel_data_type});
VLOG(6) << "multi_dot_grad API kernel: " << kernel;
auto* dev_ctx = GetDeviceContextByBackend(kernel_backend);
auto input_x_vec = PrepareData(x, kernel.InputAt(0), {});
std::vector<const phi::DenseTensor*> input_x(input_x_vec->size());
for (size_t i = 0; i < input_x.size(); ++i) {
input_x[i] = &input_x_vec->at(i);
}
auto input_out_grad = PrepareData(out_grad, kernel.InputAt(1), {});
size_t out_number = input_x.size();
std::vector<Tensor> api_output;
auto kernel_out = SetKernelOutput(out_number, kernel_backend, &api_output);
auto x_meta_vec = MakeMetaTensor(input_x);
std::vector<phi::MetaTensor*> x_metas(x_meta_vec.size());
for (size_t i = 0; i < x_meta_vec.size(); ++i) {
x_metas[i] = &x_meta_vec[i];
}
std::vector<phi::MetaTensor> meta_outs;
meta_outs.reserve(out_number);
std::vector<phi::MetaTensor*> meta_out_ptrs;
meta_out_ptrs.reserve(out_number);
for (size_t i = 0; i < out_number; ++i) {
meta_outs.push_back(kernel_out[i]);
meta_out_ptrs.push_back(&meta_outs.back());
}
phi::MultiDotGradInferMeta(
x_metas, MakeMetaTensor(*input_out_grad), meta_out_ptrs);
using kernel_signature = void (*)(const platform::DeviceContext&,
const std::vector<const phi::DenseTensor*>&,
const phi::DenseTensor&,
std::vector<phi::DenseTensor*>&);
auto* kernel_fn = kernel.GetVariadicKernelFn<kernel_signature>();
(*kernel_fn)(*dev_ctx, input_x, *input_out_grad, kernel_out);
return api_output;
}
std::vector<Tensor> multiplex_grad_impl(const std::vector<Tensor>& inputs,
const Tensor& ids,
const Tensor& out_grad) {
Backend kernel_backend = Backend::UNDEFINED;
DataLayout kernel_layout = DataLayout::UNDEFINED;
DataType kernel_data_type = DataType::UNDEFINED;
if (kernel_backend == Backend::UNDEFINED ||
kernel_layout == DataLayout::UNDEFINED ||
kernel_data_type == DataType::UNDEFINED) {
auto kernel_key_set = ParseKernelKeyByInputArgs(out_grad);
auto kernel_key = kernel_key_set.GetHighestPriorityKernelKey();
if (kernel_backend == Backend::UNDEFINED) {
kernel_backend = kernel_key.backend();
}
if (kernel_layout == DataLayout::UNDEFINED) {
kernel_layout = kernel_key.layout();
}
if (kernel_data_type == DataType::UNDEFINED) {
kernel_data_type = kernel_key.dtype();
}
}
VLOG(6) << "multiplex_grad API kernel key: [" << kernel_backend << ", "
<< kernel_layout << ", " << kernel_data_type << "]";
const auto& kernel = phi::KernelFactory::Instance().SelectKernelOrThrowError(
"multiplex_grad", {kernel_backend, kernel_layout, kernel_data_type});
VLOG(6) << "multiplex_grad API kernel: " << kernel;
auto* dev_ctx = GetDeviceContextByBackend(kernel_backend);
auto input_ids = PrepareData(ids, kernel.InputAt(0), {});
auto input_out_grad = PrepareData(out_grad, kernel.InputAt(1), {});
auto out_number = inputs.size();
std::vector<Tensor> api_output;
auto kernel_out = SetKernelOutput(out_number, kernel_backend, &api_output);
std::vector<phi::MetaTensor> meta_outs;
meta_outs.reserve(out_number);
std::vector<phi::MetaTensor*> meta_out_ptrs;
meta_out_ptrs.reserve(out_number);
for (size_t i = 0; i < out_number; ++i) {
meta_outs.push_back(kernel_out[i]);
meta_out_ptrs.push_back(&meta_outs.back());
}
phi::MultiplexGradInferMeta(MakeMetaTensor(*input_ids),
MakeMetaTensor(*input_out_grad),
meta_out_ptrs);
using kernel_signature = void (*)(const platform::DeviceContext&,
const phi::DenseTensor&,
const phi::DenseTensor&,
std::vector<phi::DenseTensor*>&);
auto* kernel_fn = kernel.GetVariadicKernelFn<kernel_signature>();
(*kernel_fn)(*dev_ctx, *input_ids, *input_out_grad, kernel_out);
return api_output;
}
}  // namespace experimental
}  // namespace paddle
...@@ -30,6 +30,20 @@ namespace experimental {
////////////////// Forward api impls //////////////////////
std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor, Tensor> batch_norm_impl(
const Tensor& x,
const Tensor& scale,
const Tensor& bias,
const Tensor& mean,
const Tensor& variance,
float momentum,
float epsilon,
const std::string& data_layout,
bool is_test,
bool use_global_stats,
bool trainable_statistics,
bool fuse_with_relu);
Tensor conv2d_impl(const Tensor& input,
                   const Tensor& filter,
                   const std::vector<int>& strides,
...@@ -62,8 +76,6 @@ std::vector<Tensor> split_impl(const Tensor& x,
                               const IntArray& num_or_sections,
                               const Scalar& axis);
std::vector<Tensor> meshgrid_impl(const std::vector<Tensor>& inputs);
std::tuple<Tensor, Tensor, Tensor> momentum_impl(
    const Tensor& param,
    const Tensor& grad,
...@@ -77,49 +89,14 @@ std::tuple<Tensor, Tensor, Tensor> momentum_impl(
    bool multi_precision,
    float rescale_grad);
std::vector<Tensor> unbind_impl(const Tensor& input, int axis);
////////////////// Backward(grad) api impls //////////////////////
std::vector<Tensor> add_n_grad_impl(const std::vector<Tensor>& x,
                                    const Tensor& out_grad);
std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor, Tensor> batch_norm_impl(
const Tensor& x,
const Tensor& scale,
const Tensor& bias,
const Tensor& mean,
const Tensor& variance,
float momentum,
float epsilon,
const std::string& data_layout,
bool is_test,
bool use_global_stats,
bool trainable_statistics,
bool fuse_with_relu);
/************************ backward api impl ***************************/
std::vector<Tensor> concat_grad_impl(const std::vector<Tensor>& x,
const Tensor& out_grad,
const Scalar& axis);
Tensor imag_grad_impl(const Tensor& x);
Tensor real_grad_impl(const Tensor& x);
std::vector<Tensor> stack_grad_impl(const std::vector<Tensor>& x,
const Tensor& out_grad,
int axis);
std::vector<Tensor> meshgrid_grad_impl(const std::vector<Tensor>& inputs,
const std::vector<Tensor>& outputs_grad);
std::vector<Tensor> multi_dot_grad_impl(const std::vector<Tensor>& x,
const Tensor& out_grad);
std::vector<Tensor> multiplex_grad_impl(const std::vector<Tensor>& inputs,
const Tensor& ids,
const Tensor& out_grad);
}  // namespace experimental
}  // namespace paddle
...@@ -76,6 +76,16 @@ std::vector<phi::MetaTensor> MakeMetaTensor(
  return meta_tensors;
}
std::vector<phi::MetaTensor> MakeMetaTensor(
const std::vector<phi::DenseTensor*>& tensors) {
std::vector<phi::MetaTensor> meta_tensors;
meta_tensors.reserve(tensors.size());
for (auto* t : tensors) {
meta_tensors.emplace_back(*t);
}
return meta_tensors;
}
phi::MetaTensor MakeMetaTensor(const phi::SelectedRows& tensor) {
  return phi::MetaTensor(tensor);
}
......
...@@ -53,6 +53,9 @@ phi::MetaTensor MakeMetaTensor(const phi::DenseTensor& tensor);
std::vector<phi::MetaTensor> MakeMetaTensor(
    const std::vector<const phi::DenseTensor*>& tensors);
std::vector<phi::MetaTensor> MakeMetaTensor(
const std::vector<phi::DenseTensor*>& tensors);
phi::MetaTensor MakeMetaTensor(const phi::SelectedRows& tensor);
phi::MetaTensor MakeMetaTensor(const phi::StringTensor& tensor);
......
...@@ -96,6 +96,8 @@ class IntArrayBase {
  template <typename OtherT>
  IntArrayBase(const IntArrayBase<OtherT>& other) : array_(other.GetData()) {}

+  size_t size() const { return array_.size(); }

  const std::vector<int64_t>& GetData() const { return array_; }

 private:
......
...@@ -19,45 +19,33 @@ limitations under the License. */
#include <tuple>

#include "paddle/phi/common/place.h"
+#include "paddle/phi/core/type_defs.h"
#include "paddle/utils/any.h"
#include "paddle/utils/flat_hash_map.h"
#include "paddle/utils/small_vector.h"

namespace phi {
constexpr char kGradVarSuffix[] = "@GRAD";
constexpr size_t kGradVarSuffixSize = 5U;
inline std::string GradVarName(const std::string& var_name) {
std::string result;
result.reserve(var_name.size() + kGradVarSuffixSize);
result += var_name;
result += kGradVarSuffix;
return result;
}
// tuple(input_names, attr_names, output_names)
-using KernelArgsTuple = std::tuple<paddle::SmallVector<std::string>,
-                                   paddle::SmallVector<std::string>,
-                                   paddle::SmallVector<std::string>>;
+using KernelArgsTuple = std::tuple<paddle::SmallVector<const char*>,
+                                   paddle::SmallVector<const char*>,
+                                   paddle::SmallVector<const char*>>;

struct KernelSignature {
-  std::string name;
+  const char* name;
  KernelArgsTuple args;

  KernelSignature() = default;
-  KernelSignature(std::string&& kernel_name,
-                  paddle::SmallVector<std::string>&& inputs,
-                  paddle::SmallVector<std::string>&& attrs,
-                  paddle::SmallVector<std::string>&& outputs)
-      : name(std::move(kernel_name)),
-        args(std::make_tuple(inputs, attrs, outputs)) {}
-  KernelSignature(const std::string& kernel_name,
-                  const paddle::SmallVector<std::string>& inputs,
-                  const paddle::SmallVector<std::string>& attrs,
-                  const paddle::SmallVector<std::string>& outputs)
+  KernelSignature(const char* kernel_name,
+                  paddle::SmallVector<const char*>&& inputs,
+                  paddle::SmallVector<const char*>&& attrs,
+                  paddle::SmallVector<const char*>&& outputs)
+      : name(kernel_name), args(std::make_tuple(inputs, attrs, outputs)) {}
+  KernelSignature(const char* kernel_name,
+                  const paddle::SmallVector<const char*>& inputs,
+                  const paddle::SmallVector<const char*>& attrs,
+                  const paddle::SmallVector<const char*>& outputs)
      : name(kernel_name), args(std::make_tuple(inputs, attrs, outputs)) {}

  // TODO(chenweihang): add assign constructor to solve windows compile
......
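For orientation, a rough standalone sketch of the pattern adopted in the hunk above: kernel signatures hold their op and argument names as string literals (const char*) inside small vectors, so building or copying a signature on the dygraph dispatch path copies pointers instead of allocating std::string objects. DemoKernelSignature, SmallVec, and the argument names below are illustrative stand-ins, not Paddle's actual definitions.

#include <cstddef>
#include <tuple>
#include <vector>

// Illustrative stand-in for paddle::SmallVector<const char*, N>.
template <typename T>
using SmallVec = std::vector<T>;

// All names refer to static-storage literals, so copying the signature is cheap.
struct DemoKernelSignature {
  const char* name;
  std::tuple<SmallVec<const char*>, SmallVec<const char*>, SmallVec<const char*>>
      args;  // (input_names, attr_names, output_names)
};

int main() {
  SmallVec<const char*> inputs{"X", "Y"};
  SmallVec<const char*> attrs{"trans_x", "trans_y"};
  SmallVec<const char*> outputs{"Out"};
  DemoKernelSignature sig{"matmul", std::make_tuple(inputs, attrs, outputs)};
  return std::get<0>(sig.args).size() == 2 ? 0 : 1;
}

The implied constraint is that every name must outlive the signature, which string literals in static storage trivially satisfy.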
...@@ -102,7 +102,7 @@ phi::Place TransToPhiPlace(const Backend& backend, bool set_device_id) {
  }
}

-std::string TransToPhiKernelName(const std::string& fluid_op_name) {
+const std::string& TransToPhiKernelName(const std::string& fluid_op_name) {
  return OpUtilsMap::Instance().GetBaseKernelName(fluid_op_name);
}
......
...@@ -22,7 +22,7 @@ limitations under the License. */
namespace phi {

-std::string TransToPhiKernelName(const std::string& fluid_op_name);
+const std::string& TransToPhiKernelName(const std::string& fluid_op_name);
const std::string& TransToFluidOpName(const std::string& phi_kernel_name);

Backend TransToPhiBackend(const phi::Place& place);
......
...@@ -26,6 +26,8 @@ limitations under the License. */
namespace phi {

+const static std::string deprecated_kernel_name = "deprecated";  // NOLINT
+
const std::unordered_set<std::string> standard_kernel_suffixs({
    "sr",   // SelectedRows kernel
    "raw"   // fallback kernel of origfinal fluid op
...@@ -134,9 +136,9 @@ class OpUtilsMap {
    arg_mapping_fn_map_.insert({std::move(op_type), std::move(fn)});
  }

-  std::string GetBaseKernelName(const std::string& op_type) const {
+  const std::string& GetBaseKernelName(const std::string& op_type) const {
    if (deprecated_op_names.find(op_type) != deprecated_op_names.end()) {
-      return "deprecated";
+      return deprecated_kernel_name;
    }
    auto it = base_kernel_name_map_.find(op_type);
    if (it == base_kernel_name_map_.end()) {
...@@ -150,7 +152,7 @@ class OpUtilsMap {
    auto it = arg_mapping_fn_map_.find(op_type);
    if (it == arg_mapping_fn_map_.end()) {
      auto func =
-          [op_type](const ArgumentMappingContext& ctx) -> KernelSignature {
+          [&op_type](const ArgumentMappingContext& ctx) -> KernelSignature {
        return DefaultKernelSignatureMap::Instance().Get(op_type);
      };
      return func;
......
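A rough standalone sketch of the lookup pattern in the hunk above, using invented names (NameRegistry, kFallbackName): the mapping from op name to base kernel name returns a const std::string& into long-lived storage, and the fallback for deprecated ops is a named static rather than a temporary built from a string literal on every call.

#include <string>
#include <unordered_map>

// Fallback lives for the whole program, so returning a reference to it is safe.
static const std::string kFallbackName = "deprecated";  // NOLINT

class NameRegistry {
 public:
  void Insert(std::string op, std::string kernel) {
    base_names_[std::move(op)] = std::move(kernel);
  }

  // Returns a reference into the map (or the static fallback); the caller
  // avoids one std::string copy per lookup on the hot dispatch path.
  const std::string& GetBaseKernelName(const std::string& op) const {
    auto it = base_names_.find(op);
    return it == base_names_.end() ? kFallbackName : it->second;
  }

 private:
  std::unordered_map<std::string, std::string> base_names_;
};

int main() {
  NameRegistry reg;
  reg.Insert("elementwise_add", "add");
  const std::string& name = reg.GetBaseKernelName("elementwise_add");
  return name == "add" ? 0 : 1;
}

The trade-off of returning references is that callers must not keep them past the registry's lifetime, which is unproblematic for a process-wide singleton.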
...@@ -20,14 +20,12 @@ void InferMetaContext::SetMetaConfig(MetaConfig config) {
  config_ = std::move(config);
}

-void InferMetaContext::EmplaceBackInput(
-    std::shared_ptr<phi::MetaTensor> input) {
+void InferMetaContext::EmplaceBackInput(MetaTensor input) {
  int index = inputs_.size();
  inputs_.emplace_back(std::move(input));
  input_range_.emplace_back(std::pair<int, int>(index, index + 1));
}
-void InferMetaContext::EmplaceBackOutput(
-    std::shared_ptr<phi::MetaTensor> output) {
+void InferMetaContext::EmplaceBackOutput(MetaTensor output) {
  int index = outputs_.size();
  outputs_.emplace_back(std::move(output));
  output_range_.emplace_back(std::pair<int, int>(index, index + 1));
...@@ -37,7 +35,7 @@ void InferMetaContext::EmplaceBackAttr(paddle::any attr) {
}

void InferMetaContext::EmplaceBackInputs(
-    paddle::SmallVector<std::shared_ptr<phi::MetaTensor>> inputs) {
+    paddle::SmallVector<MetaTensor, phi::kInputSmallVectorSize> inputs) {
  int index = inputs_.size();
  input_range_.emplace_back(std::pair<int, int>(index, index + inputs.size()));
  inputs_.insert(inputs_.end(),
...@@ -45,7 +43,7 @@ void InferMetaContext::EmplaceBackInputs(
                 std::make_move_iterator(inputs.end()));
}
void InferMetaContext::EmplaceBackOutputs(
-    paddle::SmallVector<std::shared_ptr<phi::MetaTensor>> outputs) {
+    paddle::SmallVector<MetaTensor, phi::kOutputSmallVectorSize> outputs) {
  int index = outputs_.size();
  output_range_.emplace_back(
      std::pair<int, int>(index, index + outputs.size()));
...@@ -64,24 +62,25 @@ const std::pair<int, int>& InferMetaContext::OutputRangeAt(size_t idx) const {
const MetaConfig& InferMetaContext::GetMetaConfig() const { return config_; }

const MetaTensor& InferMetaContext::InputAt(size_t idx) const {
-  return *inputs_.at(idx);
+  return inputs_.at(idx);
}

-paddle::optional<const phi::MetaTensor&> InferMetaContext::OptionalInputAt(
+paddle::optional<const MetaTensor&> InferMetaContext::OptionalInputAt(
    size_t idx) const {
  const auto& input = inputs_.at(idx);
-  return input ? paddle::optional<const phi::MetaTensor&>{static_cast<
-                     const phi::MetaTensor&>(*input)}
-               : paddle::optional<const phi::MetaTensor&>{paddle::none};
+  return input.initialized()
+             ? paddle::optional<const MetaTensor&>{input}
+             : paddle::optional<const MetaTensor&>{paddle::none};
}

-std::vector<MetaTensor*> InferMetaContext::InputsBetween(size_t start,
-                                                         size_t end) const {
-  std::vector<MetaTensor*> result;
+std::vector<const MetaTensor*> InferMetaContext::InputsBetween(
+    size_t start, size_t end) const {
+  std::vector<const MetaTensor*> result;
  result.reserve(end - start);
  for (size_t i = start; i < end; ++i) {
-    result.push_back(inputs_.at(i).get());
+    auto& in = inputs_.at(i);
+    result.emplace_back(in.initialized() ? &in : nullptr);
  }
  return result;
...@@ -91,12 +90,13 @@ paddle::optional<const std::vector<const MetaTensor*>>
InferMetaContext::OptionalInputsBetween(size_t start, size_t end) const {
  const auto& first = inputs_.at(start);
-  if (first) {
+  if (first.initialized()) {
    std::vector<const MetaTensor*> result;
    result.reserve(end - start);
    for (size_t i = start; i < end; ++i) {
-      result.push_back(inputs_.at(i).get());
+      auto& in = inputs_.at(i);
+      result.emplace_back(in.initialized() ? &in : nullptr);
    }
    return paddle::optional<const std::vector<const MetaTensor*>>(result);
...@@ -105,7 +105,8 @@ InferMetaContext::OptionalInputsBetween(size_t start, size_t end) const {
}

MetaTensor* InferMetaContext::MutableOutputAt(size_t idx) {
-  return outputs_.at(idx).get();
+  auto& out = outputs_.at(idx);
+  return out.initialized() ? &out : nullptr;
}

std::vector<MetaTensor*> InferMetaContext::MutableOutputBetween(size_t start,
...@@ -113,7 +114,8 @@ std::vector<MetaTensor*> InferMetaContext::MutableOutputBetween(size_t start,
  std::vector<MetaTensor*> result;
  result.reserve(end - start);
  for (size_t i = start; i < end; ++i) {
-    result.emplace_back(outputs_.at(i).get());
+    auto& out = outputs_.at(i);
+    result.emplace_back(out.initialized() ? &out : nullptr);
  }
  return result;
}
......
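As a rough sketch of why the rewrite above helps dygraph scheduling: the context now keeps lightweight MetaTensor values inline instead of one shared_ptr allocation per argument, and an uninitialized slot stands in for a missing optional input, surfaced to callers as nullptr. DemoMetaTensor and DemoContext below are invented, simplified stand-ins for the real phi types, not the actual implementation.

#include <cstddef>
#include <vector>

// Stand-in for phi::MetaTensor: a non-owning view that may be empty.
struct DemoMetaTensor {
  void* tensor = nullptr;  // points at the real tensor's metadata
  bool initialized() const { return tensor != nullptr; }
};

class DemoContext {
 public:
  // Passing by value and storing inline avoids a shared_ptr allocation per argument.
  void EmplaceBackInput(DemoMetaTensor input) { inputs_.emplace_back(input); }

  // Missing optional inputs are reported as nullptr instead of throwing.
  std::vector<const DemoMetaTensor*> InputsBetween(size_t start, size_t end) const {
    std::vector<const DemoMetaTensor*> out;
    out.reserve(end - start);
    for (size_t i = start; i < end; ++i) {
      const auto& in = inputs_.at(i);
      out.push_back(in.initialized() ? &in : nullptr);
    }
    return out;
  }

 private:
  std::vector<DemoMetaTensor> inputs_;  // by value, contiguous storage
};

int main() {
  int dummy = 0;
  DemoContext ctx;
  ctx.EmplaceBackInput(DemoMetaTensor{&dummy});
  ctx.EmplaceBackInput(DemoMetaTensor{});  // uninitialized optional slot
  auto views = ctx.InputsBetween(0, 2);
  return (views[0] != nullptr && views[1] == nullptr) ? 0 : 1;
}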
...@@ -37,28 +37,28 @@ class InferMetaContext {
  explicit InferMetaContext(MetaConfig config) : config_(config) {}

  void SetMetaConfig(MetaConfig config);
-  void EmplaceBackInput(std::shared_ptr<phi::MetaTensor> input);
-  void EmplaceBackOutput(std::shared_ptr<phi::MetaTensor> output);
+  const MetaConfig& GetMetaConfig() const;
+
+  void EmplaceBackInput(MetaTensor input);
+  void EmplaceBackOutput(MetaTensor output);
  void EmplaceBackAttr(paddle::any attr);

  void EmplaceBackInputs(
-      paddle::SmallVector<std::shared_ptr<phi::MetaTensor>> inputs);
+      paddle::SmallVector<MetaTensor, phi::kInputSmallVectorSize> inputs);
  void EmplaceBackOutputs(
-      paddle::SmallVector<std::shared_ptr<phi::MetaTensor>> outputs);
+      paddle::SmallVector<MetaTensor, phi::kOutputSmallVectorSize> outputs);

-  const std::pair<int, int>& InputRangeAt(size_t idx) const;
-  const std::pair<int, int>& OutputRangeAt(size_t idx) const;
-  const MetaConfig& GetMetaConfig() const;
-
-  const MetaTensor& InputAt(size_t idx) const;
-  paddle::optional<const phi::MetaTensor&> OptionalInputAt(size_t idx) const;
-  std::vector<MetaTensor*> InputsBetween(size_t start, size_t end) const;
-  paddle::optional<const std::vector<const phi::MetaTensor*>>
+  virtual const MetaTensor& InputAt(size_t idx) const;
+  virtual paddle::optional<const MetaTensor&> OptionalInputAt(size_t idx) const;
+  virtual std::vector<const MetaTensor*> InputsBetween(size_t start,
+                                                       size_t end) const;
+  virtual paddle::optional<const std::vector<const MetaTensor*>>
  OptionalInputsBetween(size_t start, size_t end) const;

-  MetaTensor* MutableOutputAt(size_t idx);
-  std::vector<MetaTensor*> MutableOutputBetween(size_t start, size_t end);
+  virtual MetaTensor* MutableOutputAt(size_t idx);
+  virtual std::vector<MetaTensor*> MutableOutputBetween(size_t start,
+                                                        size_t end);

  template <typename AttrType>
  AttrType AttrAt(size_t idx) {
...@@ -73,19 +73,24 @@ class InferMetaContext {
    }
  }

- private:
+  const std::pair<int, int>& InputRangeAt(size_t idx) const;
+  const std::pair<int, int>& OutputRangeAt(size_t idx) const;
+
+  virtual ~InferMetaContext() = default;
+
+ protected:
  MetaConfig config_;

-  // NOTE(chenweihang): Because the MetaTensor is a base class, and MetaTensor
-  // objects are all created in each round, so we have to use smart pointer
-  // here, maybe we can implemented a new InferMetaContext and a series utils
-  // specifically for fluid to avoid using shared_ptr
-  paddle::SmallVector<std::shared_ptr<phi::MetaTensor>> inputs_;
-  paddle::SmallVector<std::shared_ptr<phi::MetaTensor>> outputs_;
-  paddle::SmallVector<paddle::any> attrs_;
+  paddle::SmallVector<paddle::any, kAttrSmallVectorSize> attrs_;

-  paddle::SmallVector<std::pair<int, int>> input_range_;
-  paddle::SmallVector<std::pair<int, int>> output_range_;
+  paddle::SmallVector<std::pair<int, int>, phi::kInputSmallVectorSize>
+      input_range_;
+  paddle::SmallVector<std::pair<int, int>, phi::kOutputSmallVectorSize>
+      output_range_;
+
+ private:
+  paddle::SmallVector<MetaTensor, phi::kInputSmallVectorSize> inputs_;
+  paddle::SmallVector<MetaTensor, phi::kOutputSmallVectorSize> outputs_;
};
}; };
#define PD_INFER_META(...) \ #define PD_INFER_META(...) \
...@@ -159,7 +164,7 @@ struct InferMetaFnImpl<Return (*)(Args...), infer_meta_fn> { ...@@ -159,7 +164,7 @@ struct InferMetaFnImpl<Return (*)(Args...), infer_meta_fn> {
}; };
template <typename... Tail> template <typename... Tail>
struct InferMetaFnCallHelper<const std::vector<MetaTensor*>&, Tail...> { struct InferMetaFnCallHelper<const std::vector<const MetaTensor*>&, Tail...> {
template <int in_idx, int attr_idx, int out_idx, typename... PreviousArgs> template <int in_idx, int attr_idx, int out_idx, typename... PreviousArgs>
static void Call(InferMetaContext* ctx, PreviousArgs&... pargs) { static void Call(InferMetaContext* ctx, PreviousArgs&... pargs) {
static_assert(attr_idx == 0, static_assert(attr_idx == 0,
...@@ -167,7 +172,7 @@ struct InferMetaFnImpl<Return (*)(Args...), infer_meta_fn> { ...@@ -167,7 +172,7 @@ struct InferMetaFnImpl<Return (*)(Args...), infer_meta_fn> {
static_assert(out_idx == 0, static_assert(out_idx == 0,
"InferMeta's Input should appear before Outputs."); "InferMeta's Input should appear before Outputs.");
const std::pair<int, int> range = ctx->InputRangeAt(in_idx); const std::pair<int, int> range = ctx->InputRangeAt(in_idx);
std::vector<MetaTensor*> arg = std::vector<const MetaTensor*> arg =
ctx->InputsBetween(range.first, range.second); ctx->InputsBetween(range.first, range.second);
InferMetaFnCallHelper< InferMetaFnCallHelper<
Tail...>::template Call<in_idx + 1, attr_idx, out_idx>(ctx, Tail...>::template Call<in_idx + 1, attr_idx, out_idx>(ctx,
......
...@@ -79,7 +79,7 @@ void KernelContext::EmplaceBackAttr(paddle::any attr) {
void KernelContext::AssignInputRange(std::pair<int, int>&& range, size_t idx) {
  if (idx < input_range_.size()) {
-    input_range_[idx] = range;
+    input_range_[idx] = std::move(range);
  } else if (idx == input_range_.size()) {
    input_range_.emplace_back(range);
  } else {
...@@ -93,7 +93,7 @@ void KernelContext::AssignInputRange(std::pair<int, int>&& range, size_t idx) {
void KernelContext::AssignOutputRange(std::pair<int, int>&& range, size_t idx) {
  if (idx < output_range_.size()) {
-    output_range_[idx] = range;
+    output_range_[idx] = std::move(range);
  } else if (idx == output_range_.size()) {
    output_range_.emplace_back(range);
  } else {
......
...@@ -19,6 +19,8 @@
namespace phi {

+const static Kernel empty_kernel;  // NOLINT
+
uint32_t KernelKey::Hash::operator()(const KernelKey& key) const {
  uint32_t hash_value = 0;
  // |----31-20------|---19-12---|---11-8----|---7-0---|
...@@ -37,15 +39,15 @@ KernelFactory& KernelFactory::Instance() {
  return g_op_kernel_factory;
}

-Kernel KernelFactory::SelectKernel(const std::string& kernel_name,
-                                   const KernelKey& kernel_key) const {
+const Kernel& KernelFactory::SelectKernel(const std::string& kernel_name,
+                                          const KernelKey& kernel_key) const {
  auto iter = kernels_.find(kernel_name);
  if (iter == kernels_.end()) {
-    return Kernel();
+    return empty_kernel;
  }
  auto kernel_iter = iter->second.find(kernel_key);
  if (kernel_iter == iter->second.end()) {
-    return Kernel();
+    return empty_kernel;
  }
  return kernel_iter->second;
}
...@@ -59,8 +61,8 @@ KernelKeyMap KernelFactory::SelectKernelMap(
  return iter->second;
}

-bool KernelFactory::IsSelectKernelValid(const std::string& kernel_name,
-                                        const KernelKey& kernel_key) const {
+bool KernelFactory::HasKernel(const std::string& kernel_name,
+                              const KernelKey& kernel_key) const {
  auto iter = kernels_.find(kernel_name);
  PADDLE_ENFORCE_NE(
      iter,
...@@ -128,6 +130,16 @@ const Kernel& KernelFactory::SelectKernelOrThrowError(
                                     KernelKey(backend, layout, dtype));
}

+const KernelArgsDef& KernelFactory::GetFirstKernelArgsDef(
+    const std::string& kernel_name) const {
+  auto iter = kernels_.find(kernel_name);
+  PADDLE_ENFORCE_NE(
+      iter,
+      kernels_.end(),
+      phi::errors::NotFound("The kernel `%s` is not registered.", kernel_name));
+  return iter->second.cbegin()->second.args_def();
+}
+
// print kernel info with json format:
// {
//   "(CPU, Undefined(AnyLayout), complex64)": {
......
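A standalone sketch of the lookup convention adopted above, with invented names (DemoKernel, DemoFactory, kEmptyKernel): a failed lookup returns a reference to a single file-local empty kernel rather than constructing one by value, and the old IsSelectKernelValid-style query becomes a plain HasKernel predicate.

#include <string>
#include <unordered_map>

struct DemoKernel {
  void (*fn)() = nullptr;
  bool IsValid() const { return fn != nullptr; }
};

// One immortal "not found" object; misses return a reference to it, so the
// hot path never copies a kernel or allocates.
static const DemoKernel kEmptyKernel{};  // NOLINT

void Noop() {}

class DemoFactory {
 public:
  void Register(const std::string& name, DemoKernel k) { kernels_[name] = k; }

  const DemoKernel& Select(const std::string& name) const {
    auto it = kernels_.find(name);
    return it == kernels_.end() ? kEmptyKernel : it->second;
  }

  bool HasKernel(const std::string& name) const {
    return kernels_.count(name) != 0;
  }

 private:
  std::unordered_map<std::string, DemoKernel> kernels_;
};

int main() {
  DemoFactory factory;
  factory.Register("add", DemoKernel{&Noop});
  const DemoKernel& hit = factory.Select("add");
  const DemoKernel& miss = factory.Select("sub");
  return (hit.IsValid() && !miss.IsValid() && factory.HasKernel("add")) ? 0 : 1;
}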
...@@ -151,30 +151,38 @@ class KernelArgsDef {
    attribute_defs_.emplace_back(AttributeArgDef(type_index));
  }

-  const paddle::SmallVector<TensorArgDef>& input_defs() const {
+  const paddle::SmallVector<TensorArgDef, kInputSmallVectorSize>& input_defs()
+      const {
    return input_defs_;
  }

-  const paddle::SmallVector<TensorArgDef>& output_defs() const {
+  const paddle::SmallVector<TensorArgDef, kOutputSmallVectorSize>& output_defs()
+      const {
    return output_defs_;
  }

-  const paddle::SmallVector<AttributeArgDef>& attribute_defs() const {
+  const paddle::SmallVector<AttributeArgDef, kAttrSmallVectorSize>&
+  attribute_defs() const {
    return attribute_defs_;
  }

-  paddle::SmallVector<TensorArgDef>& input_defs() { return input_defs_; }
+  paddle::SmallVector<TensorArgDef, kInputSmallVectorSize>& input_defs() {
+    return input_defs_;
+  }

-  paddle::SmallVector<TensorArgDef>& output_defs() { return output_defs_; }
+  paddle::SmallVector<TensorArgDef, kOutputSmallVectorSize>& output_defs() {
+    return output_defs_;
+  }

-  paddle::SmallVector<AttributeArgDef>& attribute_defs() {
+  paddle::SmallVector<AttributeArgDef, kAttrSmallVectorSize>& attribute_defs() {
    return attribute_defs_;
  }

 private:
-  paddle::SmallVector<TensorArgDef> input_defs_{{}};
-  paddle::SmallVector<TensorArgDef> output_defs_{{}};
-  paddle::SmallVector<AttributeArgDef> attribute_defs_{{}};
+  paddle::SmallVector<TensorArgDef, kInputSmallVectorSize> input_defs_{{}};
+  paddle::SmallVector<TensorArgDef, kOutputSmallVectorSize> output_defs_{{}};
+  paddle::SmallVector<AttributeArgDef, kAttrSmallVectorSize> attribute_defs_{
+      {}};
};

class Kernel {
...@@ -209,7 +217,7 @@ class Kernel {
  TensorArgDef& OutputAt(size_t idx) { return args_def_.output_defs().at(idx); }

-  bool IsValid() { return fn_ != nullptr; }
+  bool IsValid() const { return fn_ != nullptr; }

 private:
  KernelFn fn_{nullptr};
...@@ -246,14 +254,17 @@ class KernelFactory {
                                           DataLayout layout,
                                           DataType dtype) const;

-  bool IsSelectKernelValid(const std::string& kernel_name,
-                           const KernelKey& kernel_key) const;
+  bool HasKernel(const std::string& kernel_name,
+                 const KernelKey& kernel_key) const;

-  Kernel SelectKernel(const std::string& kernel_name,
-                      const KernelKey& kernel_key) const;
+  const Kernel& SelectKernel(const std::string& kernel_name,
+                             const KernelKey& kernel_key) const;

  KernelKeyMap SelectKernelMap(const std::string& kernel_name) const;

+  const KernelArgsDef& GetFirstKernelArgsDef(
+      const std::string& kernel_name) const;
+
 private:
  KernelFactory() = default;
......
...@@ -148,4 +148,6 @@ void MetaTensor::share_dims(const MetaTensor& meta_tensor) {
  }
}

+bool MetaTensor::initialized() const { return tensor_ != nullptr; }
+
}  // namespace phi
...@@ -45,10 +45,10 @@ class MetaTensor {
      : tensor_(const_cast<TensorBase*>(&tensor)) {}
  MetaTensor(TensorBase& tensor) : tensor_(&tensor) {}  // NOLINT

-  MetaTensor(const MetaTensor&) = default;
  MetaTensor(MetaTensor&&) = default;
-  MetaTensor& operator=(const MetaTensor&) = delete;
-  MetaTensor& operator=(MetaTensor&&) = delete;
+  MetaTensor& operator=(MetaTensor&&) = default;
+  MetaTensor(const MetaTensor&) = default;
+  MetaTensor& operator=(const MetaTensor&) = default;

  virtual ~MetaTensor() = default;
...@@ -64,6 +64,8 @@ class MetaTensor {
  virtual void share_meta(const MetaTensor& meta_tensor);
  virtual void share_dims(const MetaTensor& meta_tensor);
virtual bool initialized() const;
 private:
  // Because the lod in compiletime and runtime is different,
  // so `LoD` cannot in public methods
......
...@@ -22,7 +22,7 @@ class Kernel;
class KernelKey;
class KernelArgsDef;
class KernelContext;
-class KernelSignature;
+struct KernelSignature;
class ArgumentMappingContext;
class InferMetaContext;
...@@ -35,4 +35,9 @@ using ArgumentMappingFn =
    std::function<KernelSignature(const ArgumentMappingContext&)>;
using InferMetaFn = void (*)(InferMetaContext* ctx);

+// Global SmallVector size setting
+constexpr size_t kInputSmallVectorSize = 10U;
+constexpr size_t kAttrSmallVectorSize = 10U;
+constexpr size_t kOutputSmallVectorSize = 5U;
+
}  // namespace phi
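For intuition, a rough standalone sketch of what the new capacity constants buy: a SmallVector-style container keeps up to N elements in an inline buffer, so typical ops (at most around ten inputs or attributes and five outputs) never touch the heap while contexts are being built. InlineVec below is a deliberately simplified stand-in for paddle::SmallVector, not its real implementation.

#include <array>
#include <cstddef>
#include <vector>

constexpr size_t kInputSmallVectorSize = 10U;

// Simplified stand-in: stores up to N elements inline, spills to the heap after.
template <typename T, size_t N>
class InlineVec {
 public:
  void push_back(const T& v) {
    if (size_ < N) {
      inline_[size_] = v;
    } else {
      heap_.push_back(v);  // only ops with more than N arguments pay for this
    }
    ++size_;
  }
  size_t size() const { return size_; }
  bool spilled() const { return size_ > N; }

 private:
  std::array<T, N> inline_{};
  std::vector<T> heap_;
  size_t size_ = 0;
};

int main() {
  InlineVec<int, kInputSmallVectorSize> inputs;
  for (int i = 0; i < 8; ++i) inputs.push_back(i);  // common case: stays inline
  return (!inputs.spilled() && inputs.size() == 8) ? 0 : 1;
}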
...@@ -315,8 +315,8 @@ void MaxPoolWithIndexGradInferMeta(const MetaTensor& x,
  dx->share_meta(x);
}

-void MeshgridGradInferMeta(const std::vector<MetaTensor*>& inputs,
-                           const std::vector<MetaTensor*>& outputs_grad,
+void MeshgridGradInferMeta(const std::vector<const MetaTensor*>& inputs,
+                           const std::vector<const MetaTensor*>& outputs_grad,
                           std::vector<MetaTensor*> inputs_grad) {
  PADDLE_ENFORCE_GT(outputs_grad.size(),
                    1,
...@@ -329,7 +329,7 @@ void MeshgridGradInferMeta(const std::vector<MetaTensor*>& inputs,
  }
}

-void MultiDotGradInferMeta(const std::vector<MetaTensor*>& x,
+void MultiDotGradInferMeta(const std::vector<const MetaTensor*>& x,
                           const MetaTensor& out_grad,
                           std::vector<MetaTensor*> x_grad) {
  PADDLE_ENFORCE_EQ(
......
...@@ -151,11 +151,11 @@ void MaxPoolWithIndexGradInferMeta(const MetaTensor& x,
                               bool adaptive,
                               MetaTensor* dx);

-void MeshgridGradInferMeta(const std::vector<MetaTensor*>& inputs,
-                           const std::vector<MetaTensor*>& outputs_grad,
+void MeshgridGradInferMeta(const std::vector<const MetaTensor*>& inputs,
+                           const std::vector<const MetaTensor*>& outputs_grad,
                           std::vector<MetaTensor*> inputs_grad);

-void MultiDotGradInferMeta(const std::vector<MetaTensor*>& x,
+void MultiDotGradInferMeta(const std::vector<const MetaTensor*>& x,
                           const MetaTensor& out_grad,
                           std::vector<MetaTensor*> x_grad);
......
...@@ -21,7 +21,8 @@ limitations under the License. */
#include "paddle/phi/kernels/funcs/concat_funcs.h"

namespace phi {

-std::vector<DDim> GetMetaTensorsDim(const std::vector<MetaTensor*>& tensors) {
+std::vector<DDim> GetMetaTensorsDim(
+    const std::vector<const MetaTensor*>& tensors) {
  std::vector<DDim> dims;
  dims.reserve(tensors.size());
  for (const MetaTensor* tensor : tensors) {
...@@ -148,7 +149,7 @@ void AdamaxInferMeta(const MetaTensor& param,
  inf_norm_out->set_dtype(inf_norm.dtype());
}

-void AddNInferMeta(const std::vector<MetaTensor*>& x,
+void AddNInferMeta(const std::vector<const MetaTensor*>& x,
                   MetaTensor* out,
                   MetaConfig config) {
  auto N = x.size();
...@@ -511,7 +512,7 @@ void BilinearTensorProductInferMeta(const MetaTensor& x,
  out->set_dtype(x.dtype());
}

-void BroadcastTensorsInferMeta(const std::vector<MetaTensor*>& x,
+void BroadcastTensorsInferMeta(const std::vector<const MetaTensor*>& x,
                               std::vector<MetaTensor*> out) {
  int target_rank = 0;
  const auto& input_dims = GetMetaTensorsDim(x);
...@@ -565,7 +566,7 @@ void BroadcastTensorsInferMeta(const std::vector<MetaTensor*>& x,
  }
}

-void ConcatInferMeta(const std::vector<MetaTensor*>& x,
+void ConcatInferMeta(const std::vector<const MetaTensor*>& x,
                     const Scalar& axis_scalar,
                     MetaTensor* out,
                     MetaConfig config) {
...@@ -1357,7 +1358,7 @@ void InterpolateInferMeta(
  }
}

-void MeshgridInferMeta(const std::vector<MetaTensor*>& inputs,
+void MeshgridInferMeta(const std::vector<const MetaTensor*>& inputs,
                       std::vector<MetaTensor*> outputs) {
  const size_t inputs_num = inputs.size();
...@@ -1420,7 +1421,8 @@ void MomentumInferMeta(const MetaTensor& param,
  }
}

-void MultiDotInferMeta(const std::vector<MetaTensor*>& x, MetaTensor* out) {
+void MultiDotInferMeta(const std::vector<const MetaTensor*>& x,
+                       MetaTensor* out) {
  auto inputs_dims = GetMetaTensorsDim(x);
  const size_t inputs_num = inputs_dims.size();
...@@ -1493,7 +1495,7 @@ void MultiDotInferMeta(const std::vector<MetaTensor*>& x, MetaTensor* out) {
  out->share_lod(*x.at(0));
}

-void MultiplexInferMeta(const std::vector<MetaTensor*>& ins,
+void MultiplexInferMeta(const std::vector<const MetaTensor*>& ins,
                        const MetaTensor& ids,
                        MetaTensor* out) {
  PADDLE_ENFORCE_NE(
...@@ -1672,8 +1674,8 @@ void RmspropInferMeta(const MetaTensor& param,
}

void RnnInferMeta(const MetaTensor& x,
-                  const std::vector<MetaTensor*>& pre_state,
-                  const std::vector<MetaTensor*>& weight_list,
+                  const std::vector<const MetaTensor*>& pre_state,
+                  const std::vector<const MetaTensor*>& weight_list,
                  paddle::optional<const MetaTensor&> sequence_length,
                  float dropout_prob,
                  bool is_bidirec,
...@@ -1779,7 +1781,7 @@ void SGDInferMeta(const MetaTensor& param,
  param_out->set_dtype(param.dtype());
}

-void StackInferMeta(const std::vector<MetaTensor*>& x,
+void StackInferMeta(const std::vector<const MetaTensor*>& x,
                    int axis,
                    MetaTensor* out) {
  PADDLE_ENFORCE_GT(x.size(),
...@@ -1825,7 +1827,7 @@ void StackInferMeta(const std::vector<MetaTensor*>& x,
  out->share_lod(*x.at(0));
}

-void UnchangedMultiInferMeta(const std::vector<MetaTensor*>& x,
+void UnchangedMultiInferMeta(const std::vector<const MetaTensor*>& x,
                             std::vector<MetaTensor*> out) {
  for (size_t i = 0; i < x.size(); ++i) {
    out[i]->share_meta(*x[i]);
......
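For readers unfamiliar with the convention these signatures follow, a small standalone sketch with a made-up DemoMeta type: multi-tensor InferMeta-style functions read inputs through const pointers and write outputs through mutable pointers, so the pointer vectors built by the context can be passed straight through without const_casts or copies.

#include <cstddef>
#include <cstdint>
#include <vector>

struct DemoMeta {
  std::vector<int64_t> dims;
};

// Inputs are read-only views; outputs are written in place.
void DemoUnchangedMultiInferMeta(const std::vector<const DemoMeta*>& x,
                                 std::vector<DemoMeta*> out) {
  for (size_t i = 0; i < x.size(); ++i) {
    out[i]->dims = x[i]->dims;  // propagate each input's shape to its output
  }
}

int main() {
  DemoMeta a{{2, 3}}, b{{4}};
  DemoMeta out_a, out_b;
  DemoUnchangedMultiInferMeta({&a, &b}, {&out_a, &out_b});
  return (out_a.dims.size() == 2 && out_b.dims.size() == 1) ? 0 : 1;
}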
...@@ -35,7 +35,8 @@ namespace phi {
//
// NOTE: The InferMeta Functions in this file are arranged in alphabetic order.

-std::vector<DDim> GetMetaTensorsDim(const std::vector<MetaTensor*>& tensors);
+std::vector<DDim> GetMetaTensorsDim(
+    const std::vector<const MetaTensor*>& tensors);
void AdadeltaInferMeta(const MetaTensor& param,
                       const MetaTensor& grad,
...@@ -68,7 +69,7 @@ void AdamaxInferMeta(const MetaTensor& param,
                     MetaTensor* moment_out,
                     MetaTensor* inf_norm_out);

-void AddNInferMeta(const std::vector<MetaTensor*>& x,
+void AddNInferMeta(const std::vector<const MetaTensor*>& x,
                   MetaTensor* out,
                   MetaConfig config = MetaConfig());
...@@ -124,10 +125,10 @@ void BilinearTensorProductInferMeta(const MetaTensor& x,
                                    MetaTensor* out,
                                    MetaConfig config = MetaConfig());

-void BroadcastTensorsInferMeta(const std::vector<MetaTensor*>& x,
+void BroadcastTensorsInferMeta(const std::vector<const MetaTensor*>& x,
                               std::vector<MetaTensor*> out);

-void ConcatInferMeta(const std::vector<MetaTensor*>& x,
+void ConcatInferMeta(const std::vector<const MetaTensor*>& x,
                     const Scalar& axis_scalar,
                     MetaTensor* out,
                     MetaConfig config = MetaConfig());
...@@ -178,7 +179,7 @@ void InterpolateInferMeta( ...@@ -178,7 +179,7 @@ void InterpolateInferMeta(
MetaTensor* output, MetaTensor* output,
MetaConfig config = MetaConfig()); MetaConfig config = MetaConfig());
void MeshgridInferMeta(const std::vector<MetaTensor*>& inputs, void MeshgridInferMeta(const std::vector<const MetaTensor*>& inputs,
std::vector<MetaTensor*> outputs); std::vector<MetaTensor*> outputs);
void MomentumInferMeta(const MetaTensor& param, void MomentumInferMeta(const MetaTensor& param,
...@@ -196,9 +197,10 @@ void MomentumInferMeta(const MetaTensor& param, ...@@ -196,9 +197,10 @@ void MomentumInferMeta(const MetaTensor& param,
MetaTensor* velocity_out, MetaTensor* velocity_out,
MetaTensor* master_param_out); MetaTensor* master_param_out);
void MultiDotInferMeta(const std::vector<MetaTensor*>& x, MetaTensor* out); void MultiDotInferMeta(const std::vector<const MetaTensor*>& x,
MetaTensor* out);
void MultiplexInferMeta(const std::vector<MetaTensor*>& ins, void MultiplexInferMeta(const std::vector<const MetaTensor*>& ins,
const MetaTensor& ids, const MetaTensor& ids,
MetaTensor* out); MetaTensor* out);
...@@ -227,8 +229,8 @@ void RmspropInferMeta(const MetaTensor& param, ...@@ -227,8 +229,8 @@ void RmspropInferMeta(const MetaTensor& param,
MetaTensor* mean_grad_out); MetaTensor* mean_grad_out);
void RnnInferMeta(const MetaTensor& x, void RnnInferMeta(const MetaTensor& x,
const std::vector<MetaTensor*>& pre_state, const std::vector<const MetaTensor*>& pre_state,
const std::vector<MetaTensor*>& weight_list, const std::vector<const MetaTensor*>& weight_list,
paddle::optional<const MetaTensor&> sequence_length, paddle::optional<const MetaTensor&> sequence_length,
float dropout_prob, float dropout_prob,
bool is_bidirec, bool is_bidirec,
...@@ -251,11 +253,11 @@ void SGDInferMeta(const MetaTensor& param, ...@@ -251,11 +253,11 @@ void SGDInferMeta(const MetaTensor& param,
MetaTensor* param_out, MetaTensor* param_out,
MetaTensor* master_param_out); MetaTensor* master_param_out);
void StackInferMeta(const std::vector<MetaTensor*>& x, void StackInferMeta(const std::vector<const MetaTensor*>& x,
int axis, int axis,
MetaTensor* out); MetaTensor* out);
void UnchangedMultiInferMeta(const std::vector<MetaTensor*>& x, void UnchangedMultiInferMeta(const std::vector<const MetaTensor*>& x,
std::vector<MetaTensor*> out); std::vector<MetaTensor*> out);
void WarpctcInferMeta(const MetaTensor& logits, void WarpctcInferMeta(const MetaTensor& logits,
......
...@@ -32,7 +32,7 @@ DenseTensor Concat(const Context& dev_ctx, ...@@ -32,7 +32,7 @@ DenseTensor Concat(const Context& dev_ctx,
const Scalar& axis) { const Scalar& axis) {
std::vector<MetaTensor> meta_x; std::vector<MetaTensor> meta_x;
meta_x.reserve(x.size()); meta_x.reserve(x.size());
std::vector<MetaTensor*> meta_x_ptr; std::vector<const MetaTensor*> meta_x_ptr;
for (const auto* t : x) { for (const auto* t : x) {
meta_x.emplace_back(*t); meta_x.emplace_back(*t);
meta_x_ptr.push_back(&meta_x.back()); meta_x_ptr.push_back(&meta_x.back());
......
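The Concat helper above shows the calling pattern implied by the new const-qualified InferMeta signatures. Below is a minimal sketch of that pattern written out separately; the wrapper name is made up and the header paths are the usual phi locations (assumptions for illustration, not code from this patch).

// Minimal sketch (not part of this patch) of calling an InferMeta function
// that now takes std::vector<const MetaTensor*>.
#include <vector>
#include "paddle/phi/common/scalar.h"        // phi::Scalar (assumed path)
#include "paddle/phi/core/dense_tensor.h"    // phi::DenseTensor (assumed path)
#include "paddle/phi/core/meta_tensor.h"     // phi::MetaTensor (assumed path)
#include "paddle/phi/infermeta/multiary.h"   // phi::ConcatInferMeta (assumed path)

void InferConcatShape(const std::vector<const phi::DenseTensor*>& x,
                      const phi::Scalar& axis,
                      phi::DenseTensor* dense_out) {
  // Materialize MetaTensor wrappers first, then collect const pointers;
  // reserve() keeps the pointers stable while emplacing.
  std::vector<phi::MetaTensor> meta_x;
  meta_x.reserve(x.size());
  std::vector<const phi::MetaTensor*> meta_x_ptr;
  for (const auto* t : x) {
    meta_x.emplace_back(*t);
    meta_x_ptr.push_back(&meta_x.back());
  }
  phi::MetaTensor meta_out(dense_out);
  phi::ConcatInferMeta(meta_x_ptr, axis, &meta_out);
}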
...@@ -21,8 +21,7 @@ KernelSignature AbsOpArgumentMapping(const ArgumentMappingContext& ctx) { ...@@ -21,8 +21,7 @@ KernelSignature AbsOpArgumentMapping(const ArgumentMappingContext& ctx) {
} }
KernelSignature AbsGradOpArgumentMapping(const ArgumentMappingContext& ctx) { KernelSignature AbsGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature( return KernelSignature("abs_grad", {"X", "Out@GRAD"}, {}, {"X@GRAD"});
"abs_grad", {"X", GradVarName("Out")}, {}, {GradVarName("X")});
} }
KernelSignature AbsDoubleGradOpArgumentMapping( KernelSignature AbsDoubleGradOpArgumentMapping(
......
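The change above replaces GradVarName("Out") with the literal string "Out@GRAD". Assuming GradVarName simply appends the gradient suffix "@GRAD" (the fluid naming convention for gradient variables), both spellings yield the same mapping, and the literal skips the string construction on every dispatch. A small self-contained sketch of that equivalence, using a hypothetical stand-in for GradVarName:

// Sketch only: this GradVarName is a stand-in written for illustration,
// not the fluid implementation.
#include <cassert>
#include <string>

inline std::string GradVarName(const std::string& var_name) {
  return var_name + "@GRAD";  // assumed suffix convention for gradient vars
}

int main() {
  assert(GradVarName("Out") == "Out@GRAD");  // the string the mapping now spells out directly
  assert(GradVarName("X") == "X@GRAD");
  return 0;
}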
...@@ -19,26 +19,22 @@ namespace phi { ...@@ -19,26 +19,22 @@ namespace phi {
#define DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(func_name, op_name, attrs) \ #define DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(func_name, op_name, attrs) \
KernelSignature func_name##GradOpArgumentMapping( \ KernelSignature func_name##GradOpArgumentMapping( \
const ArgumentMappingContext& ctx) { \ const ArgumentMappingContext& ctx) { \
return KernelSignature(op_name "_grad", \ return KernelSignature( \
{"X", GradVarName("Out")}, \ op_name "_grad", {"X", "Out@GRAD"}, {attrs}, {"X@GRAD"}); \
{attrs}, \
{GradVarName("X")}); \
} }
#define DEFINE_ACT_GRAD_DEPOUT_OP_ARGMAP(func_name, op_name, attrs) \ #define DEFINE_ACT_GRAD_DEPOUT_OP_ARGMAP(func_name, op_name, attrs) \
KernelSignature func_name##GradOpArgumentMapping( \ KernelSignature func_name##GradOpArgumentMapping( \
const ArgumentMappingContext& ctx) { \ const ArgumentMappingContext& ctx) { \
return KernelSignature(op_name "_grad", \ return KernelSignature( \
{"Out", GradVarName("Out")}, \ op_name "_grad", {"Out", "Out@GRAD"}, {attrs}, {"X@GRAD"}); \
{attrs}, \
{GradVarName("X")}); \
} }
#define DEFINE_ACT_GRAD_NODEP_OP_ARGMAP(func_name, op_name, attrs) \ #define DEFINE_ACT_GRAD_NODEP_OP_ARGMAP(func_name, op_name, attrs) \
KernelSignature func_name##GradOpArgumentMapping( \ KernelSignature func_name##GradOpArgumentMapping( \
const ArgumentMappingContext& ctx) { \ const ArgumentMappingContext& ctx) { \
return KernelSignature( \ return KernelSignature( \
op_name "_grad", {GradVarName("Out")}, {attrs}, {GradVarName("X")}); \ op_name "_grad", {"Out@GRAD"}, {attrs}, {"X@GRAD"}); \
} }
#define comma , #define comma ,
...@@ -165,15 +161,12 @@ KernelSignature EluOpArgumentMapping(const ArgumentMappingContext& ctx) { ...@@ -165,15 +161,12 @@ KernelSignature EluOpArgumentMapping(const ArgumentMappingContext& ctx) {
} }
KernelSignature LogitGradOpArgumentMapping(const ArgumentMappingContext& ctx) { KernelSignature LogitGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature( return KernelSignature("logit_grad", {"X", "Out@GRAD"}, {"eps"}, {"X@GRAD"});
"logit_grad", {"X", GradVarName("Out")}, {"eps"}, {GradVarName("X")});
} }
KernelSignature EluGradOpArgumentMapping(const ArgumentMappingContext& ctx) { KernelSignature EluGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("elu_grad", return KernelSignature(
{"X", "Out", GradVarName("Out")}, "elu_grad", {"X", "Out", "Out@GRAD"}, {"alpha"}, {"X@GRAD"});
{"alpha"},
{GradVarName("X")});
} }
KernelSignature EluDoubleGradOpArgumentMapping( KernelSignature EluDoubleGradOpArgumentMapping(
...@@ -198,13 +191,11 @@ KernelSignature PowOpArgumentMapping(const ArgumentMappingContext& ctx) { ...@@ -198,13 +191,11 @@ KernelSignature PowOpArgumentMapping(const ArgumentMappingContext& ctx) {
KernelSignature PowGradOpArgumentMapping(const ArgumentMappingContext& ctx) { KernelSignature PowGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
if (ctx.HasInput("FactorTensor")) { if (ctx.HasInput("FactorTensor")) {
return KernelSignature("pow_grad", return KernelSignature(
{"X", GradVarName("Out")}, "pow_grad", {"X", "Out@GRAD"}, {"FactorTensor"}, {"X@GRAD"});
{"FactorTensor"},
{GradVarName("X")});
} else { } else {
return KernelSignature( return KernelSignature(
"pow_grad", {"X", GradVarName("Out")}, {"factor"}, {GradVarName("X")}); "pow_grad", {"X", "Out@GRAD"}, {"factor"}, {"X@GRAD"});
} }
} }
......
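For reference, a hand expansion of the updated DEFINE_ACT_GRAD_DEPX_OP_ARGMAP macro for a hypothetical activation named "foo" with an empty attribute list (illustration only; the op name and function are not from this patch, and the header path is assumed):

// DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(Foo, "foo", ) expands to roughly:
#include "paddle/phi/core/compat/op_utils.h"  // assumed: KernelSignature, ArgumentMappingContext

namespace phi {
KernelSignature FooGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
  return KernelSignature("foo_grad", {"X", "Out@GRAD"}, {}, {"X@GRAD"});
}
}  // namespace phi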
...@@ -17,11 +17,10 @@ ...@@ -17,11 +17,10 @@
namespace phi { namespace phi {
KernelSignature AddmmGradOpArgumentMapping(const ArgumentMappingContext& ctx) { KernelSignature AddmmGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature( return KernelSignature("addmm_grad",
"addmm_grad", {"Input", "X", "Y", "Out@GRAD"},
{"Input", "X", "Y", GradVarName("Out")}, {"Alpha", "Beta"},
{"Alpha", "Beta"}, {"Input@GRAD", "X@GRAD", "Y@GRAD"});
{GradVarName("Input"), GradVarName("X"), GradVarName("Y")});
} }
} // namespace phi } // namespace phi
......
...@@ -19,9 +19,9 @@ namespace phi { ...@@ -19,9 +19,9 @@ namespace phi {
KernelSignature ArgsortGradOpArgumentMapping( KernelSignature ArgsortGradOpArgumentMapping(
const ArgumentMappingContext& ctx) { const ArgumentMappingContext& ctx) {
return KernelSignature("argsort_grad", return KernelSignature("argsort_grad",
{"Indices", "X", GradVarName("Out")}, {"Indices", "X", "Out@GRAD"},
{"axis", "descending"}, {"axis", "descending"},
{GradVarName("X")}); {"X@GRAD"});
} }
} // namespace phi } // namespace phi
......
...@@ -17,10 +17,8 @@ ...@@ -17,10 +17,8 @@
namespace phi { namespace phi {
KernelSignature Atan2GradOpArgumentMapping(const ArgumentMappingContext& ctx) { KernelSignature Atan2GradOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("atan2_grad", return KernelSignature(
{"X1", "X2", GradVarName("Out")}, "atan2_grad", {"X1", "X2", "Out@GRAD"}, {}, {"X1@GRAD", "X2@GRAD"});
{},
{GradVarName("X1"), GradVarName("X2")});
} }
} // namespace phi } // namespace phi
......
...@@ -57,27 +57,26 @@ KernelSignature BatchNormOpArgumentMapping(const ArgumentMappingContext& ctx) { ...@@ -57,27 +57,26 @@ KernelSignature BatchNormOpArgumentMapping(const ArgumentMappingContext& ctx) {
KernelSignature BatchNormGradOpArgumentMapping( KernelSignature BatchNormGradOpArgumentMapping(
const ArgumentMappingContext& ctx) { const ArgumentMappingContext& ctx) {
return KernelSignature( return KernelSignature("batch_norm_grad",
"batch_norm_grad", {
{ "X",
"X", "Scale",
"Scale", "Bias",
"Bias", "Mean",
"Mean", "Variance",
"Variance", "SavedMean",
"SavedMean", "SavedVariance",
"SavedVariance", "ReserveSpace",
"ReserveSpace", "Y@GRAD",
GradVarName("Y"), },
}, {"momentum",
{"momentum", "epsilon",
"epsilon", "data_layout",
"data_layout", "is_test",
"is_test", "use_global_stats",
"use_global_stats", "trainable_statistics",
"trainable_statistics", "fuse_with_relu"},
"fuse_with_relu"}, {"X@GRAD", "Scale@GRAD", "Bias@GRAD"});
{GradVarName("X"), GradVarName("Scale"), GradVarName("Bias")});
} }
KernelSignature BatchNormGradGradOpArgumentMapping( KernelSignature BatchNormGradGradOpArgumentMapping(
......
...@@ -18,10 +18,8 @@ namespace phi { ...@@ -18,10 +18,8 @@ namespace phi {
KernelSignature BCELossGradOpArgumentMapping( KernelSignature BCELossGradOpArgumentMapping(
const ArgumentMappingContext& ctx) { const ArgumentMappingContext& ctx) {
return KernelSignature("bce_loss_grad", return KernelSignature(
{"X", "Label", GradVarName("Out")}, "bce_loss_grad", {"X", "Label", "Out@GRAD"}, {}, {"X@GRAD"});
{},
{GradVarName("X")});
} }
} // namespace phi } // namespace phi
......
...@@ -25,12 +25,9 @@ KernelSignature BilinearTensorProductOpArgumentMapping( ...@@ -25,12 +25,9 @@ KernelSignature BilinearTensorProductOpArgumentMapping(
KernelSignature BilinearTensorProductGradOpArgumentMapping( KernelSignature BilinearTensorProductGradOpArgumentMapping(
const ArgumentMappingContext& ctx) { const ArgumentMappingContext& ctx) {
return KernelSignature("bilinear_tensor_product_grad", return KernelSignature("bilinear_tensor_product_grad",
{"X", "Y", "Weight", GradVarName("Out")}, {"X", "Y", "Weight", "Out@GRAD"},
{}, {},
{GradVarName("X"), {"X@GRAD", "Y@GRAD", "Weight@GRAD", "Bias@GRAD"});
GradVarName("Y"),
GradVarName("Weight"),
GradVarName("Bias")});
} }
} // namespace phi } // namespace phi
......
...@@ -19,7 +19,7 @@ namespace phi { ...@@ -19,7 +19,7 @@ namespace phi {
KernelSignature BroadcastTensorsGradOpArgumentMapping( KernelSignature BroadcastTensorsGradOpArgumentMapping(
const ArgumentMappingContext& ctx) { const ArgumentMappingContext& ctx) {
return KernelSignature( return KernelSignature(
"broadcast_tensors_grad", {GradVarName("Out")}, {}, {GradVarName("X")}); "broadcast_tensors_grad", {"Out@GRAD"}, {}, {"X@GRAD"});
} }
} // namespace phi } // namespace phi
......
...@@ -18,10 +18,8 @@ namespace phi { ...@@ -18,10 +18,8 @@ namespace phi {
KernelSignature CholeskyGradOpArgumentMapping( KernelSignature CholeskyGradOpArgumentMapping(
const ArgumentMappingContext& ctx) { const ArgumentMappingContext& ctx) {
return KernelSignature("cholesky_grad", return KernelSignature(
{"Out", GradVarName("Out")}, "cholesky_grad", {"Out", "Out@GRAD"}, {"upper"}, {"X@GRAD"});
{"upper"},
{GradVarName("X")});
} }
} // namespace phi } // namespace phi
......
...@@ -19,9 +19,9 @@ namespace phi { ...@@ -19,9 +19,9 @@ namespace phi {
KernelSignature CholeskySolveGradOpArgumentMapping( KernelSignature CholeskySolveGradOpArgumentMapping(
const ArgumentMappingContext& ctx) { const ArgumentMappingContext& ctx) {
return KernelSignature("cholesky_solve_grad", return KernelSignature("cholesky_solve_grad",
{"X", "Y", "Out", GradVarName("Out")}, {"X", "Y", "Out", "Out@GRAD"},
{"upper"}, {"upper"},
{GradVarName("X"), GradVarName("Y")}); {"X@GRAD", "Y@GRAD"});
} }
} // namespace phi } // namespace phi
......
...@@ -18,7 +18,7 @@ ...@@ -18,7 +18,7 @@
namespace phi { namespace phi {
KernelSignature ClipOpArgumentMapping(const ArgumentMappingContext& ctx) { KernelSignature ClipOpArgumentMapping(const ArgumentMappingContext& ctx) {
paddle::SmallVector<std::string> attr_names; paddle::SmallVector<std::string, kAttrSmallVectorSize> attr_names;
attr_names.emplace_back(ctx.HasInput("Min") ? "Min" : "min"); attr_names.emplace_back(ctx.HasInput("Min") ? "Min" : "min");
attr_names.emplace_back(ctx.HasInput("Max") ? "Max" : "max"); attr_names.emplace_back(ctx.HasInput("Max") ? "Max" : "max");
if (ctx.IsDenseTensorInput("X")) { if (ctx.IsDenseTensorInput("X")) {
...@@ -57,27 +57,19 @@ KernelSignature ClipOpArgumentMapping(const ArgumentMappingContext& ctx) { ...@@ -57,27 +57,19 @@ KernelSignature ClipOpArgumentMapping(const ArgumentMappingContext& ctx) {
KernelSignature ClipGradOpArgumentMapping(const ArgumentMappingContext& ctx) { KernelSignature ClipGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
if (ctx.HasInput("Min")) { if (ctx.HasInput("Min")) {
if (ctx.HasInput("Max")) { if (ctx.HasInput("Max")) {
return KernelSignature("clip_grad", return KernelSignature(
{"X", GradVarName("Out")}, "clip_grad", {"X", "Out@GRAD"}, {"Min", "Max"}, {"X@GRAD"});
{"Min", "Max"},
{GradVarName("X")});
} else { } else {
return KernelSignature("clip_grad", return KernelSignature(
{"X", GradVarName("Out")}, "clip_grad", {"X", "Out@GRAD"}, {"Min", "max"}, {"X@GRAD"});
{"Min", "max"},
{GradVarName("X")});
} }
} else { } else {
if (ctx.HasInput("Max")) { if (ctx.HasInput("Max")) {
return KernelSignature("clip_grad", return KernelSignature(
{"X", GradVarName("Out")}, "clip_grad", {"X", "Out@GRAD"}, {"min", "Max"}, {"X@GRAD"});
{"min", "Max"},
{GradVarName("X")});
} else { } else {
return KernelSignature("clip_grad", return KernelSignature(
{"X", GradVarName("Out")}, "clip_grad", {"X", "Out@GRAD"}, {"min", "max"}, {"X@GRAD"});
{"min", "max"},
{GradVarName("X")});
} }
} }
} }
......
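The Clip mapping above now declares attr_names with an explicit inline capacity. A minimal sketch of the presumed intent (the header path and the capacity value are assumptions for illustration): short attribute-name lists stay in SmallVector's inline storage, avoiding a heap allocation per op dispatch on the dygraph path.

#include <cstddef>
#include <string>
#include "paddle/utils/small_vector.h"  // assumed location of paddle::SmallVector

constexpr size_t kAttrSmallVectorSizeForIllustration = 8;  // placeholder value

void BuildClipAttrNames(bool has_min_tensor, bool has_max_tensor) {
  // Two short names fit in the inline storage, so no allocation happens here.
  paddle::SmallVector<std::string, kAttrSmallVectorSizeForIllustration> attr_names;
  attr_names.emplace_back(has_min_tensor ? "Min" : "min");
  attr_names.emplace_back(has_max_tensor ? "Max" : "max");
}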
...@@ -17,13 +17,11 @@ ...@@ -17,13 +17,11 @@
namespace phi { namespace phi {
KernelSignature RealGradOpArgumentMapping(const ArgumentMappingContext& ctx) { KernelSignature RealGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature( return KernelSignature("real_grad", {"Out@GRAD"}, {}, {"X@GRAD"});
"real_grad", {GradVarName("Out")}, {}, {GradVarName("X")});
} }
KernelSignature ImagGradOpArgumentMapping(const ArgumentMappingContext& ctx) { KernelSignature ImagGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature( return KernelSignature("imag_grad", {"Out@GRAD"}, {}, {"X@GRAD"});
"imag_grad", {GradVarName("Out")}, {}, {GradVarName("X")});
} }
} // namespace phi } // namespace phi
......
...@@ -25,15 +25,11 @@ KernelSignature ConcatOpArgumentMapping(const ArgumentMappingContext& ctx) { ...@@ -25,15 +25,11 @@ KernelSignature ConcatOpArgumentMapping(const ArgumentMappingContext& ctx) {
KernelSignature ConcatGradOpArgumentMapping(const ArgumentMappingContext& ctx) { KernelSignature ConcatGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
if (ctx.HasInput("AxisTensor")) { if (ctx.HasInput("AxisTensor")) {
return KernelSignature("concat_grad", return KernelSignature(
{"X", {GradVarName("Out")}}, "concat_grad", {"X", {"Out@GRAD"}}, {"AxisTensor"}, {{"X@GRAD"}});
{"AxisTensor"},
{{GradVarName("X")}});
} }
return KernelSignature("concat_grad", return KernelSignature(
{"X", {GradVarName("Out")}}, "concat_grad", {"X", {"Out@GRAD"}}, {"axis"}, {{"X@GRAD"}});
{"axis"},
{{GradVarName("X")}});
} }
} // namespace phi } // namespace phi
......
...@@ -46,7 +46,7 @@ KernelSignature Conv2dOpArgumentMapping(const ArgumentMappingContext& ctx) { ...@@ -46,7 +46,7 @@ KernelSignature Conv2dOpArgumentMapping(const ArgumentMappingContext& ctx) {
KernelSignature Conv2dGradOpArgumentMapping(const ArgumentMappingContext& ctx) { KernelSignature Conv2dGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("conv2d_grad", return KernelSignature("conv2d_grad",
{"Input", "Filter", GradVarName("Output")}, {"Input", "Filter", "Output@GRAD"},
{"strides", {"strides",
"paddings", "paddings",
"padding_algorithm", "padding_algorithm",
...@@ -56,7 +56,7 @@ KernelSignature Conv2dGradOpArgumentMapping(const ArgumentMappingContext& ctx) { ...@@ -56,7 +56,7 @@ KernelSignature Conv2dGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
"use_addto", "use_addto",
"workspace_size_MB", "workspace_size_MB",
"exhaustive_search"}, "exhaustive_search"},
{GradVarName("Input"), GradVarName("Filter")}); {"Input@GRAD", "Filter@GRAD"});
} }
KernelSignature Conv2dDoubleGradOpArgumentMapping( KernelSignature Conv2dDoubleGradOpArgumentMapping(
......
...@@ -33,7 +33,7 @@ KernelSignature Conv3dOpArgumentMapping(const ArgumentMappingContext& ctx) { ...@@ -33,7 +33,7 @@ KernelSignature Conv3dOpArgumentMapping(const ArgumentMappingContext& ctx) {
KernelSignature Conv3dGradOpArgumentMapping(const ArgumentMappingContext& ctx) { KernelSignature Conv3dGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("conv2d_grad", return KernelSignature("conv2d_grad",
{"Input", "Filter", GradVarName("Output")}, {"Input", "Filter", "Output@GRAD"},
{"strides", {"strides",
"paddings", "paddings",
"padding_algorithm", "padding_algorithm",
...@@ -43,7 +43,7 @@ KernelSignature Conv3dGradOpArgumentMapping(const ArgumentMappingContext& ctx) { ...@@ -43,7 +43,7 @@ KernelSignature Conv3dGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
"use_addto", "use_addto",
"workspace_size_MB", "workspace_size_MB",
"exhaustive_search"}, "exhaustive_search"},
{GradVarName("Input"), GradVarName("Filter")}); {"Input@GRAD", "Filter@GRAD"});
} }
KernelSignature Conv3dDoubleGradOpArgumentMapping( KernelSignature Conv3dDoubleGradOpArgumentMapping(
......
...@@ -34,7 +34,7 @@ KernelSignature Conv2dTransposeOpArgumentMapping( ...@@ -34,7 +34,7 @@ KernelSignature Conv2dTransposeOpArgumentMapping(
KernelSignature Conv2dTransposeGradOpArgumentMapping( KernelSignature Conv2dTransposeGradOpArgumentMapping(
const ArgumentMappingContext& ctx) { const ArgumentMappingContext& ctx) {
return KernelSignature("conv2d_transpose_grad", return KernelSignature("conv2d_transpose_grad",
{"Input", "Filter", GradVarName("Output")}, {"Input", "Filter", "Output@GRAD"},
{"strides", {"strides",
"paddings", "paddings",
"output_padding", "output_padding",
...@@ -43,7 +43,7 @@ KernelSignature Conv2dTransposeGradOpArgumentMapping( ...@@ -43,7 +43,7 @@ KernelSignature Conv2dTransposeGradOpArgumentMapping(
"groups", "groups",
"dilations", "dilations",
"data_format"}, "data_format"},
{GradVarName("Input"), GradVarName("Filter")}); {"Input@GRAD", "Filter@GRAD"});
} }
KernelSignature Conv2dTransposeDoubleGradOpArgumentMapping( KernelSignature Conv2dTransposeDoubleGradOpArgumentMapping(
...@@ -79,7 +79,7 @@ KernelSignature Conv3dTransposeOpArgumentMapping( ...@@ -79,7 +79,7 @@ KernelSignature Conv3dTransposeOpArgumentMapping(
KernelSignature Conv3dTransposeGradOpArgumentMapping( KernelSignature Conv3dTransposeGradOpArgumentMapping(
const ArgumentMappingContext& ctx) { const ArgumentMappingContext& ctx) {
return KernelSignature("conv3d_transpose_grad", return KernelSignature("conv3d_transpose_grad",
{"Input", "Filter", GradVarName("Output")}, {"Input", "Filter", "Output@GRAD"},
{"strides", {"strides",
"paddings", "paddings",
"output_padding", "output_padding",
...@@ -88,7 +88,7 @@ KernelSignature Conv3dTransposeGradOpArgumentMapping( ...@@ -88,7 +88,7 @@ KernelSignature Conv3dTransposeGradOpArgumentMapping(
"groups", "groups",
"dilations", "dilations",
"data_format"}, "data_format"},
{GradVarName("Input"), GradVarName("Filter")}); {"Input@GRAD", "Filter@GRAD"});
} }
KernelSignature DepthwiseConv2dTransposeOpArgumentMapping( KernelSignature DepthwiseConv2dTransposeOpArgumentMapping(
...@@ -109,7 +109,7 @@ KernelSignature DepthwiseConv2dTransposeOpArgumentMapping( ...@@ -109,7 +109,7 @@ KernelSignature DepthwiseConv2dTransposeOpArgumentMapping(
KernelSignature DepthwiseConv2dTransposeGradOpArgumentMapping( KernelSignature DepthwiseConv2dTransposeGradOpArgumentMapping(
const ArgumentMappingContext& ctx) { const ArgumentMappingContext& ctx) {
return KernelSignature("depthwise_conv2d_transpose_grad", return KernelSignature("depthwise_conv2d_transpose_grad",
{"Input", "Filter", GradVarName("Output")}, {"Input", "Filter", "Output@GRAD"},
{"strides", {"strides",
"paddings", "paddings",
"output_padding", "output_padding",
...@@ -118,7 +118,7 @@ KernelSignature DepthwiseConv2dTransposeGradOpArgumentMapping( ...@@ -118,7 +118,7 @@ KernelSignature DepthwiseConv2dTransposeGradOpArgumentMapping(
"groups", "groups",
"dilations", "dilations",
"data_format"}, "data_format"},
{GradVarName("Input"), GradVarName("Filter")}); {"Input@GRAD", "Filter@GRAD"});
} }
} // namespace phi } // namespace phi
......
...@@ -21,10 +21,8 @@ KernelSignature CrossOpArgumentMapping(const ArgumentMappingContext& ctx) { ...@@ -21,10 +21,8 @@ KernelSignature CrossOpArgumentMapping(const ArgumentMappingContext& ctx) {
} }
KernelSignature CrossGradOpArgumentMapping(const ArgumentMappingContext& ctx) { KernelSignature CrossGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("cross_grad", return KernelSignature(
{"X", "Y", GradVarName("Out")}, "cross_grad", {"X", "Y", "Out@GRAD"}, {"dim"}, {"X@GRAD", "Y@GRAD"});
{"dim"},
{GradVarName("X"), GradVarName("Y")});
} }
} // namespace phi } // namespace phi
......
...@@ -18,10 +18,8 @@ namespace phi { ...@@ -18,10 +18,8 @@ namespace phi {
KernelSignature CumprodGradGradOpArgumentMapping( KernelSignature CumprodGradGradOpArgumentMapping(
const ArgumentMappingContext& ctx) { const ArgumentMappingContext& ctx) {
return KernelSignature("cumprod_grad", return KernelSignature(
{"X", "Out", GradVarName("Out")}, "cumprod_grad", {"X", "Out", "Out@GRAD"}, {"dim"}, {"X@GRAD"});
{"dim"},
{GradVarName("X")});
} }
} // namespace phi } // namespace phi
......
...@@ -33,17 +33,14 @@ KernelSignature DeformableConvGradOpArgumentMapping( ...@@ -33,17 +33,14 @@ KernelSignature DeformableConvGradOpArgumentMapping(
const ArgumentMappingContext& ctx) { const ArgumentMappingContext& ctx) {
return KernelSignature( return KernelSignature(
"deformable_conv_grad", "deformable_conv_grad",
{"Input", "Offset", "Filter", "Mask", GradVarName("Output")}, {"Input", "Offset", "Filter", "Mask", "Output@GRAD"},
{"strides", {"strides",
"paddings", "paddings",
"dilations", "dilations",
"deformable_groups", "deformable_groups",
"groups", "groups",
"im2col_step"}, "im2col_step"},
{GradVarName("Input"), {"Input@GRAD", "Offset@GRAD", "Filter@GRAD", "Mask@GRAD"});
GradVarName("Offset"),
GradVarName("Filter"),
GradVarName("Mask")});
} }
} // namespace phi } // namespace phi
......
...@@ -36,7 +36,7 @@ KernelSignature DepthwiseConv2dOpArgumentMapping( ...@@ -36,7 +36,7 @@ KernelSignature DepthwiseConv2dOpArgumentMapping(
KernelSignature DepthwiseConv2dGradOpArgumentMapping( KernelSignature DepthwiseConv2dGradOpArgumentMapping(
const ArgumentMappingContext& ctx) { const ArgumentMappingContext& ctx) {
return KernelSignature("depthwise_conv2d_grad", return KernelSignature("depthwise_conv2d_grad",
{"Input", "Filter", GradVarName("Output")}, {"Input", "Filter", "Output@GRAD"},
{"strides", {"strides",
"paddings", "paddings",
"padding_algorithm", "padding_algorithm",
...@@ -47,7 +47,7 @@ KernelSignature DepthwiseConv2dGradOpArgumentMapping( ...@@ -47,7 +47,7 @@ KernelSignature DepthwiseConv2dGradOpArgumentMapping(
"workspace_size_MB", "workspace_size_MB",
"exhaustive_search", "exhaustive_search",
"fuse_relu_before_depthwise_conv"}, "fuse_relu_before_depthwise_conv"},
{GradVarName("Input"), GradVarName("Filter")}); {"Input@GRAD", "Filter@GRAD"});
} }
KernelSignature DepthwiseConv2dDoubleGradOpArgumentMapping( KernelSignature DepthwiseConv2dDoubleGradOpArgumentMapping(
......
...@@ -18,10 +18,8 @@ namespace phi { ...@@ -18,10 +18,8 @@ namespace phi {
KernelSignature DeterminantGradOpArgumentMapping( KernelSignature DeterminantGradOpArgumentMapping(
const ArgumentMappingContext& ctx) { const ArgumentMappingContext& ctx) {
return KernelSignature("determinant_grad", return KernelSignature(
{"Input", "Out", GradVarName("Out")}, "determinant_grad", {"Input", "Out", "Out@GRAD"}, {}, {"Input@GRAD"});
{},
{GradVarName("Input")});
} }
} // namespace phi } // namespace phi
......
...@@ -22,7 +22,7 @@ KernelSignature DiagOpArgumentMapping(const ArgumentMappingContext& ctx) { ...@@ -22,7 +22,7 @@ KernelSignature DiagOpArgumentMapping(const ArgumentMappingContext& ctx) {
KernelSignature DiagGradOpArgumentMapping(const ArgumentMappingContext& ctx) { KernelSignature DiagGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature( return KernelSignature(
"diag_grad", {"X", GradVarName("Out")}, {"offset"}, {GradVarName("X")}); "diag_grad", {"X", "Out@GRAD"}, {"offset"}, {"X@GRAD"});
} }
} // namespace phi } // namespace phi
......
...@@ -19,9 +19,9 @@ namespace phi { ...@@ -19,9 +19,9 @@ namespace phi {
KernelSignature DiagonalGradOpArgumentMapping( KernelSignature DiagonalGradOpArgumentMapping(
const ArgumentMappingContext& ctx) { const ArgumentMappingContext& ctx) {
return KernelSignature("diagonal_grad", return KernelSignature("diagonal_grad",
{"Input", GradVarName("Out")}, {"Input", "Out@GRAD"},
{"offset", "axis1", "axis2"}, {"offset", "axis1", "axis2"},
{GradVarName("Input")}); {"Input@GRAD"});
} }
} // namespace phi } // namespace phi
......
...@@ -18,8 +18,7 @@ namespace phi { ...@@ -18,8 +18,7 @@ namespace phi {
KernelSignature DigammaGradOpArgumentMapping( KernelSignature DigammaGradOpArgumentMapping(
const ArgumentMappingContext& ctx) { const ArgumentMappingContext& ctx) {
return KernelSignature( return KernelSignature("digamma_grad", {"X", "Out@GRAD"}, {}, {"X@GRAD"});
"digamma_grad", {"X", GradVarName("Out")}, {}, {GradVarName("X")});
} }
} // namespace phi } // namespace phi
......
...@@ -17,10 +17,8 @@ limitations under the License. */ ...@@ -17,10 +17,8 @@ limitations under the License. */
namespace phi { namespace phi {
KernelSignature DistGradOpArgumentMapping(const ArgumentMappingContext& ctx) { KernelSignature DistGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("dist_grad", return KernelSignature(
{"X", "Y", "Out", GradVarName("Out")}, "dist_grad", {"X", "Y", "Out", "Out@GRAD"}, {"p"}, {"X@GRAD", "Y@GRAD"});
{"p"},
{GradVarName("X"), GradVarName("Y")});
} }
} // namespace phi } // namespace phi
......
...@@ -17,10 +17,8 @@ limitations under the License. */ ...@@ -17,10 +17,8 @@ limitations under the License. */
namespace phi { namespace phi {
KernelSignature DotGradOpArgumentMapping(const ArgumentMappingContext& ctx) { KernelSignature DotGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("dot_grad", return KernelSignature(
{"X", "Y", GradVarName("Out")}, "dot_grad", {"X", "Y", "Out@GRAD"}, {}, {"X@GRAD", "Y@GRAD"});
{},
{GradVarName("X"), GradVarName("Y")});
} }
} // namespace phi } // namespace phi
......
...@@ -27,9 +27,9 @@ KernelSignature DropoutOpArgumentMapping(const ArgumentMappingContext& ctx) { ...@@ -27,9 +27,9 @@ KernelSignature DropoutOpArgumentMapping(const ArgumentMappingContext& ctx) {
KernelSignature DropoutGradOpArgumentMapping( KernelSignature DropoutGradOpArgumentMapping(
const ArgumentMappingContext& ctx) { const ArgumentMappingContext& ctx) {
return KernelSignature("dropout_grad", return KernelSignature("dropout_grad",
{"Mask", GradVarName("Out")}, {"Mask", "Out@GRAD"},
{"dropout_prob", "is_test", "dropout_implementation"}, {"dropout_prob", "is_test", "dropout_implementation"},
{GradVarName("X")}); {"X@GRAD"});
} }
} // namespace phi } // namespace phi
......
...@@ -17,13 +17,11 @@ ...@@ -17,13 +17,11 @@
namespace phi { namespace phi {
KernelSignature EighGradOpArgumentMapping(const ArgumentMappingContext& ctx) { KernelSignature EighGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("eigh_grad", return KernelSignature(
{"Eigenvalues", "eigh_grad",
"Eigenvectors", {"Eigenvalues", "Eigenvectors", "Eigenvalues@GRAD", "Eigenvectors@GRAD"},
GradVarName("Eigenvalues"), {},
GradVarName("Eigenvectors")}, {"X@GRAD"});
{},
{GradVarName("X")});
} }
} // namespace phi } // namespace phi
......
...@@ -106,10 +106,8 @@ KernelSignature ElementwisePowOpArgumentMapping( ...@@ -106,10 +106,8 @@ KernelSignature ElementwisePowOpArgumentMapping(
KernelSignature ElementwiseAddGradOpArgumentMapping( KernelSignature ElementwiseAddGradOpArgumentMapping(
const ArgumentMappingContext& ctx) { const ArgumentMappingContext& ctx) {
return KernelSignature("add_grad", return KernelSignature(
{"X", "Y", GradVarName("Out")}, "add_grad", {"X", "Y", "Out@GRAD"}, {"axis"}, {"X@GRAD", "Y@GRAD"});
{"axis"},
{GradVarName("X"), GradVarName("Y")});
} }
KernelSignature ElementwiseAddDoubleGradOpArgumentMapping( KernelSignature ElementwiseAddDoubleGradOpArgumentMapping(
...@@ -128,10 +126,8 @@ KernelSignature ElementwiseAddTripleGradOpArgumentMapping( ...@@ -128,10 +126,8 @@ KernelSignature ElementwiseAddTripleGradOpArgumentMapping(
KernelSignature ElementwiseSubGradOpArgumentMapping( KernelSignature ElementwiseSubGradOpArgumentMapping(
const ArgumentMappingContext& ctx) { const ArgumentMappingContext& ctx) {
return KernelSignature("subtract_grad", return KernelSignature(
{"X", "Y", GradVarName("Out")}, "subtract_grad", {"X", "Y", "Out@GRAD"}, {"axis"}, {"X@GRAD", "Y@GRAD"});
{"axis"},
{GradVarName("X"), GradVarName("Y")});
} }
KernelSignature ElementwiseSubDoubleGradOpArgumentMapping( KernelSignature ElementwiseSubDoubleGradOpArgumentMapping(
...@@ -143,17 +139,15 @@ KernelSignature ElementwiseSubDoubleGradOpArgumentMapping( ...@@ -143,17 +139,15 @@ KernelSignature ElementwiseSubDoubleGradOpArgumentMapping(
KernelSignature ElementwiseDivGradOpArgumentMapping( KernelSignature ElementwiseDivGradOpArgumentMapping(
const ArgumentMappingContext& ctx) { const ArgumentMappingContext& ctx) {
return KernelSignature("divide_grad", return KernelSignature("divide_grad",
{"X", "Y", "Out", GradVarName("Out")}, {"X", "Y", "Out", "Out@GRAD"},
{"axis"}, {"axis"},
{GradVarName("X"), GradVarName("Y")}); {"X@GRAD", "Y@GRAD"});
} }
KernelSignature ElementwiseFMinGradOpArgumentMapping( KernelSignature ElementwiseFMinGradOpArgumentMapping(
const ArgumentMappingContext& ctx) { const ArgumentMappingContext& ctx) {
return KernelSignature("fmin_grad", return KernelSignature(
{"X", "Y", GradVarName("Out")}, "fmin_grad", {"X", "Y", "Out@GRAD"}, {"axis"}, {"X@GRAD", "Y@GRAD"});
{"axis"},
{GradVarName("X"), GradVarName("Y")});
} }
KernelSignature ElementwiseDivDoubleGradOpArgumentMapping( KernelSignature ElementwiseDivDoubleGradOpArgumentMapping(
...@@ -161,15 +155,13 @@ KernelSignature ElementwiseDivDoubleGradOpArgumentMapping( ...@@ -161,15 +155,13 @@ KernelSignature ElementwiseDivDoubleGradOpArgumentMapping(
return KernelSignature("divide_double_grad", return KernelSignature("divide_double_grad",
{"Y", "Out", "DX", "DDX", "DDY"}, {"Y", "Out", "DX", "DDX", "DDY"},
{"axis"}, {"axis"},
{GradVarName("Y"), "DOut", "DDOut"}); {"Y@GRAD", "DOut", "DDOut"});
} }
KernelSignature ElementwiseMulGradOpArgumentMapping( KernelSignature ElementwiseMulGradOpArgumentMapping(
const ArgumentMappingContext& ctx) { const ArgumentMappingContext& ctx) {
return KernelSignature("multiply_grad", return KernelSignature(
{"X", "Y", GradVarName("Out")}, "multiply_grad", {"X", "Y", "Out@GRAD"}, {"axis"}, {"X@GRAD", "Y@GRAD"});
{"axis"},
{GradVarName("X"), GradVarName("Y")});
} }
KernelSignature ElementwiseFMaxOpArgumentMapping( KernelSignature ElementwiseFMaxOpArgumentMapping(
...@@ -184,10 +176,8 @@ KernelSignature ElementwiseFMinOpArgumentMapping( ...@@ -184,10 +176,8 @@ KernelSignature ElementwiseFMinOpArgumentMapping(
KernelSignature ElementwiseFMaxGradOpArgumentMapping( KernelSignature ElementwiseFMaxGradOpArgumentMapping(
const ArgumentMappingContext& ctx) { const ArgumentMappingContext& ctx) {
return KernelSignature("fmax_grad", return KernelSignature(
{"X", "Y", GradVarName("Out")}, "fmax_grad", {"X", "Y", "Out@GRAD"}, {"axis"}, {"X@GRAD", "Y@GRAD"});
{"axis"},
{GradVarName("X"), GradVarName("Y")});
} }
KernelSignature ElementwiseMulDoubleGradOpArgumentMapping( KernelSignature ElementwiseMulDoubleGradOpArgumentMapping(
...@@ -195,7 +185,7 @@ KernelSignature ElementwiseMulDoubleGradOpArgumentMapping( ...@@ -195,7 +185,7 @@ KernelSignature ElementwiseMulDoubleGradOpArgumentMapping(
return KernelSignature("multiply_double_grad", return KernelSignature("multiply_double_grad",
{"X", "Y", "DOut", "DDX", "DDY"}, {"X", "Y", "DOut", "DDX", "DDY"},
{"axis"}, {"axis"},
{GradVarName("X"), GradVarName("Y"), "DDOut"}); {"X@GRAD", "Y@GRAD", "DDOut"});
} }
KernelSignature ElementwiseMulTripleGradOpArgumentMapping( KernelSignature ElementwiseMulTripleGradOpArgumentMapping(
...@@ -209,25 +199,21 @@ KernelSignature ElementwiseMulTripleGradOpArgumentMapping( ...@@ -209,25 +199,21 @@ KernelSignature ElementwiseMulTripleGradOpArgumentMapping(
KernelSignature ElementwiseMaxGradOpArgumentMapping( KernelSignature ElementwiseMaxGradOpArgumentMapping(
const ArgumentMappingContext& ctx) { const ArgumentMappingContext& ctx) {
return KernelSignature("maximum_grad", return KernelSignature(
{"X", "Y", GradVarName("Out")}, "maximum_grad", {"X", "Y", "Out@GRAD"}, {"axis"}, {"X@GRAD", "Y@GRAD"});
{"axis"},
{GradVarName("X"), GradVarName("Y")});
} }
KernelSignature ElementwiseMinGradOpArgumentMapping( KernelSignature ElementwiseMinGradOpArgumentMapping(
const ArgumentMappingContext& ctx) { const ArgumentMappingContext& ctx) {
return KernelSignature("minimum_grad", return KernelSignature(
{"X", "Y", GradVarName("Out")}, "minimum_grad", {"X", "Y", "Out@GRAD"}, {"axis"}, {"X@GRAD", "Y@GRAD"});
{"axis"},
{GradVarName("X"), GradVarName("Y")});
} }
KernelSignature ElementwisePowGradOpArgumentMapping( KernelSignature ElementwisePowGradOpArgumentMapping(
const ArgumentMappingContext& ctx) { const ArgumentMappingContext& ctx) {
return KernelSignature("elementwise_pow_grad", return KernelSignature("elementwise_pow_grad",
{"X", "Y", GradVarName("Out")}, {"X", "Y", "Out@GRAD"},
{"axis"}, {"axis"},
{GradVarName("X"), GradVarName("Y")}); {"X@GRAD", "Y@GRAD"});
} }
} // namespace phi } // namespace phi
......
...@@ -30,26 +30,26 @@ KernelSignature EmbeddingGradOpArgumentMapping( ...@@ -30,26 +30,26 @@ KernelSignature EmbeddingGradOpArgumentMapping(
if (ctx.IsDenseTensorInput("W")) { if (ctx.IsDenseTensorInput("W")) {
if ((paddle::any_cast<bool>(ctx.Attr("is_sparse"))) == true) { if ((paddle::any_cast<bool>(ctx.Attr("is_sparse"))) == true) {
return KernelSignature("embedding_sparse_grad", return KernelSignature("embedding_sparse_grad",
{"Ids", "W", GradVarName("Out")}, {"Ids", "W", "Out@GRAD"},
{"padding_idx"}, {"padding_idx"},
{GradVarName("W")}); {"W@GRAD"});
} else { } else {
return KernelSignature("embedding_grad", return KernelSignature("embedding_grad",
{"Ids", "W", GradVarName("Out")}, {"Ids", "W", "Out@GRAD"},
{"padding_idx"}, {"padding_idx"},
{GradVarName("W")}); {"W@GRAD"});
} }
} else { } else {
if ((paddle::any_cast<bool>(ctx.Attr("is_sparse"))) == true) { if ((paddle::any_cast<bool>(ctx.Attr("is_sparse"))) == true) {
return KernelSignature("sparse_weight_embedding_sparse_grad", return KernelSignature("sparse_weight_embedding_sparse_grad",
{"Ids", "W", GradVarName("Out")}, {"Ids", "W", "Out@GRAD"},
{"padding_idx"}, {"padding_idx"},
{GradVarName("W")}); {"W@GRAD"});
} else { } else {
return KernelSignature("sparse_weight_embedding_grad", return KernelSignature("sparse_weight_embedding_grad",
{"Ids", "W", GradVarName("Out")}, {"Ids", "W", "Out@GRAD"},
{"padding_idx"}, {"padding_idx"},
{GradVarName("W")}); {"W@GRAD"});
} }
} }
} }
......
...@@ -17,8 +17,7 @@ ...@@ -17,8 +17,7 @@
namespace phi { namespace phi {
KernelSignature ErfGradOpArgumentMapping(const ArgumentMappingContext& ctx) { KernelSignature ErfGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature( return KernelSignature("erf_grad", {"X", "Out@GRAD"}, {}, {"X@GRAD"});
"erf_grad", {"X", GradVarName("Out")}, {}, {GradVarName("X")});
} }
} // namespace phi } // namespace phi
......
...@@ -17,8 +17,7 @@ ...@@ -17,8 +17,7 @@
namespace phi { namespace phi {
KernelSignature ErfinvGradOpArgumentMapping(const ArgumentMappingContext& ctx) { KernelSignature ErfinvGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature( return KernelSignature("erfinv_grad", {"Out", "Out@GRAD"}, {}, {"X@GRAD"});
"erfinv_grad", {"Out", GradVarName("Out")}, {}, {GradVarName("X")});
} }
} // namespace phi } // namespace phi
......
...@@ -22,10 +22,8 @@ KernelSignature ExpandAsOpArgumentMapping(const ArgumentMappingContext& ctx) { ...@@ -22,10 +22,8 @@ KernelSignature ExpandAsOpArgumentMapping(const ArgumentMappingContext& ctx) {
KernelSignature ExpandAsGradOpArgumentMapping( KernelSignature ExpandAsGradOpArgumentMapping(
const ArgumentMappingContext& ctx) { const ArgumentMappingContext& ctx) {
return KernelSignature("expand_as_grad", return KernelSignature(
{"X", GradVarName("Out")}, "expand_as_grad", {"X", "Out@GRAD"}, {"target_shape"}, {"X@GRAD"});
{"target_shape"},
{GradVarName("X")});
} }
} // namespace phi } // namespace phi
......
...@@ -39,20 +39,14 @@ KernelSignature ExpandGradOpArgumentMapping(const ArgumentMappingContext& ctx) { ...@@ -39,20 +39,14 @@ KernelSignature ExpandGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
"expand_grad", {"X", "Out@GRAD"}, {"shape"}, {"X@GRAD"}); "expand_grad", {"X", "Out@GRAD"}, {"shape"}, {"X@GRAD"});
} }
if (ctx.HasInput("Shape")) { if (ctx.HasInput("Shape")) {
return KernelSignature("expand_grad", return KernelSignature(
{"X", GradVarName("Out")}, "expand_grad", {"X", "Out@GRAD"}, {"Shape"}, {"X@GRAD"});
{"Shape"},
{GradVarName("X")});
} else if (ctx.InputSize("expand_shapes_tensor") > 0) { } else if (ctx.InputSize("expand_shapes_tensor") > 0) {
return KernelSignature("expand_grad", return KernelSignature(
{"X", GradVarName("Out")}, "expand_grad", {"X", "Out@GRAD"}, {"expand_shapes_tensor"}, {"X@GRAD"});
{"expand_shapes_tensor"},
{GradVarName("X")});
} else { } else {
return KernelSignature("expand_grad", return KernelSignature(
{"X", GradVarName("Out")}, "expand_grad", {"X", "Out@GRAD"}, {"shape"}, {"X@GRAD"});
{"shape"},
{GradVarName("X")});
} }
} }
......
...@@ -31,7 +31,7 @@ KernelSignature FlattenOpArgumentMapping(const ArgumentMappingContext& ctx) { ...@@ -31,7 +31,7 @@ KernelSignature FlattenOpArgumentMapping(const ArgumentMappingContext& ctx) {
KernelSignature FlattenGradOpArgumentMapping( KernelSignature FlattenGradOpArgumentMapping(
const ArgumentMappingContext& ctx) { const ArgumentMappingContext& ctx) {
return KernelSignature( return KernelSignature(
"flatten_grad", {"XShape", GradVarName("Out")}, {}, {GradVarName("X")}); "flatten_grad", {"XShape", "Out@GRAD"}, {}, {"X@GRAD"});
} }
} // namespace phi } // namespace phi
......
...@@ -25,9 +25,9 @@ KernelSignature FrobeniusNormOpArgumentMapping( ...@@ -25,9 +25,9 @@ KernelSignature FrobeniusNormOpArgumentMapping(
KernelSignature FrobeniusNormGradOpArgumentMapping( KernelSignature FrobeniusNormGradOpArgumentMapping(
const ArgumentMappingContext& ctx) { const ArgumentMappingContext& ctx) {
return KernelSignature("frobenius_norm_grad", return KernelSignature("frobenius_norm_grad",
{"X", "Out", GradVarName("Out")}, {"X", "Out", "Out@GRAD"},
{"dim", "keep_dim", "reduce_all"}, {"dim", "keep_dim", "reduce_all"},
{GradVarName("X")}); {"X@GRAD"});
} }
} // namespace phi } // namespace phi
......
...@@ -17,25 +17,23 @@ ...@@ -17,25 +17,23 @@
namespace phi { namespace phi {
KernelSignature GatherNdGradArgumentMapping(const ArgumentMappingContext& ctx) { KernelSignature GatherNdGradArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("gather_nd_grad", return KernelSignature(
{"X", "Index", GradVarName("Out")}, "gather_nd_grad", {"X", "Index", "Out@GRAD"}, {}, {"X@GRAD"});
{},
{GradVarName("X")});
} }
KernelSignature ScatterGradArgumentMapping(const ArgumentMappingContext& ctx) { KernelSignature ScatterGradArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("scatter_grad", return KernelSignature("scatter_grad",
{"Ids", "Updates", GradVarName("Out")}, {"Ids", "Updates", "Out@GRAD"},
{"overwrite"}, {"overwrite"},
{GradVarName("X"), GradVarName("Updates")}); {"X@GRAD", "Updates@GRAD"});
} }
KernelSignature ScatterNdAddGradArgumentMapping( KernelSignature ScatterNdAddGradArgumentMapping(
const ArgumentMappingContext& ctx) { const ArgumentMappingContext& ctx) {
return KernelSignature("scatter_nd_add_grad", return KernelSignature("scatter_nd_add_grad",
{"Index", "Updates", GradVarName("Out")}, {"Index", "Updates", "Out@GRAD"},
{}, {},
{GradVarName("X"), GradVarName("Updates")}); {"X@GRAD", "Updates@GRAD"});
} }
} // namespace phi } // namespace phi
......
...@@ -27,14 +27,14 @@ KernelSignature GatherOpArgumentMapping(const ArgumentMappingContext& ctx) { ...@@ -27,14 +27,14 @@ KernelSignature GatherOpArgumentMapping(const ArgumentMappingContext& ctx) {
KernelSignature GatherGradOpArgumentMapping(const ArgumentMappingContext& ctx) { KernelSignature GatherGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
if (ctx.HasInput("Axis")) { if (ctx.HasInput("Axis")) {
return KernelSignature("gather_grad", return KernelSignature("gather_grad",
{"X", "Index", GradVarName("Out")}, {"X", "Index", "Out@GRAD"},
{"Axis", "overwrite"}, {"Axis", "overwrite"},
{GradVarName("X")}); {"X@GRAD"});
} else { } else {
return KernelSignature("gather_grad", return KernelSignature("gather_grad",
{"X", "Index", GradVarName("Out")}, {"X", "Index", "Out@GRAD"},
{"axis", "overwrite"}, {"axis", "overwrite"},
{GradVarName("X")}); {"X@GRAD"});
} }
} }
......
...@@ -21,10 +21,8 @@ KernelSignature GeluOpArgumentMapping(const ArgumentMappingContext& ctx) { ...@@ -21,10 +21,8 @@ KernelSignature GeluOpArgumentMapping(const ArgumentMappingContext& ctx) {
} }
KernelSignature GeluGradOpArgumentMapping(const ArgumentMappingContext& ctx) { KernelSignature GeluGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("gelu_grad", return KernelSignature(
{"X", GradVarName("Out")}, "gelu_grad", {"X", "Out@GRAD"}, {"approximate"}, {"X@GRAD"});
{"approximate"},
{GradVarName("X")});
} }
} // namespace phi } // namespace phi
......
...@@ -28,9 +28,9 @@ KernelSignature GraphSendRecvGradOpArgumentMapping( ...@@ -28,9 +28,9 @@ KernelSignature GraphSendRecvGradOpArgumentMapping(
const ArgumentMappingContext& ctx) { const ArgumentMappingContext& ctx) {
return KernelSignature( return KernelSignature(
"graph_send_recv_grad", "graph_send_recv_grad",
{"X", "Src_index", "Dst_index", "Out", "Dst_count", GradVarName("Out")}, {"X", "Src_index", "Dst_index", "Out", "Dst_count", "Out@GRAD"},
{"pool_type"}, {"pool_type"},
{GradVarName("X")}); {"X@GRAD"});
} }
} // namespace phi } // namespace phi
......
...@@ -27,9 +27,9 @@ KernelSignature GridSamplerOpArgumentMapping( ...@@ -27,9 +27,9 @@ KernelSignature GridSamplerOpArgumentMapping(
KernelSignature GridSamplerGradOpArgumentMapping( KernelSignature GridSamplerGradOpArgumentMapping(
const ArgumentMappingContext& ctx) { const ArgumentMappingContext& ctx) {
return KernelSignature("grid_sample_grad", return KernelSignature("grid_sample_grad",
{"X", "Grid", GradVarName("Output")}, {"X", "Grid", "Output@GRAD"},
{"mode", "padding_mode", "align_corners"}, {"mode", "padding_mode", "align_corners"},
{GradVarName("X"), GradVarName("Grid")}); {"X@GRAD", "Grid@GRAD"});
} }
} // namespace phi } // namespace phi
......
...@@ -18,10 +18,8 @@ namespace phi { ...@@ -18,10 +18,8 @@ namespace phi {
KernelSignature GumbelSoftmaxGradOpArgumentMapping( KernelSignature GumbelSoftmaxGradOpArgumentMapping(
const ArgumentMappingContext& ctx) { const ArgumentMappingContext& ctx) {
return KernelSignature("gumbel_softmax_grad", return KernelSignature(
{"Out", GradVarName("Out")}, "gumbel_softmax_grad", {"Out", "Out@GRAD"}, {"axis"}, {"X@GRAD"});
{"axis"},
{GradVarName("X")});
} }
} // namespace phi } // namespace phi
......
...@@ -32,44 +32,42 @@ KernelSignature HierarchicalSigmoidOpArgumentMapping( ...@@ -32,44 +32,42 @@ KernelSignature HierarchicalSigmoidOpArgumentMapping(
KernelSignature HierarchicalSigmoidGradOpArgumentMapping( KernelSignature HierarchicalSigmoidGradOpArgumentMapping(
const ArgumentMappingContext& ctx) { const ArgumentMappingContext& ctx) {
if (ctx.IsDenseTensorOutput(GradVarName("W"))) { if (ctx.IsDenseTensorOutput("W@GRAD")) {
return KernelSignature( return KernelSignature("hierarchical_sigmoid_grad",
"hierarchical_sigmoid_grad", {"X",
{"X", "W",
"W", "Label",
"Label", "PathTable",
"PathTable", "PathCode",
"PathCode", "Bias",
"Bias", "PreOut",
"PreOut", "Out@GRAD"},
GradVarName("Out")}, {"num_classes",
{"num_classes", "remote_prefetch",
"remote_prefetch", "trainer_id",
"trainer_id", "height_sections",
"height_sections", "epmap",
"epmap", "table_names",
"table_names", "is_sparse"},
"is_sparse"}, {"X@GRAD", "W@GRAD", "Bias@GRAD"});
{GradVarName("X"), GradVarName("W"), GradVarName("Bias")}); } else if (ctx.IsSelectedRowsOutput("W@GRAD")) {
} else if (ctx.IsSelectedRowsOutput(GradVarName("W"))) { return KernelSignature("hierarchical_sigmoid_grad_sr",
return KernelSignature( {"X",
"hierarchical_sigmoid_grad_sr", "W",
{"X", "Label",
"W", "PathTable",
"Label", "PathCode",
"PathTable", "Bias",
"PathCode", "PreOut",
"Bias", "Out@GRAD"},
"PreOut", {"num_classes",
GradVarName("Out")}, "remote_prefetch",
{"num_classes", "trainer_id",
"remote_prefetch", "height_sections",
"trainer_id", "epmap",
"height_sections", "table_names",
"epmap", "is_sparse"},
"table_names", {"X@GRAD", "W@GRAD", "Bias@GRAD"});
"is_sparse"},
{GradVarName("X"), GradVarName("W"), GradVarName("Bias")});
} else { } else {
return KernelSignature("unregistered", {}, {}, {}); return KernelSignature("unregistered", {}, {}, {});
} }
......
...@@ -24,9 +24,9 @@ KernelSignature HuberLossOpArgumentMapping(const ArgumentMappingContext& ctx) { ...@@ -24,9 +24,9 @@ KernelSignature HuberLossOpArgumentMapping(const ArgumentMappingContext& ctx) {
KernelSignature HuberLossGradOpArgumentMapping( KernelSignature HuberLossGradOpArgumentMapping(
const ArgumentMappingContext& ctx) { const ArgumentMappingContext& ctx) {
return KernelSignature("huber_loss_grad", return KernelSignature("huber_loss_grad",
{"Residual", GradVarName("Out")}, {"Residual", "Out@GRAD"},
{"delta"}, {"delta"},
{GradVarName("X"), GradVarName("Y")}); {"X@GRAD", "Y@GRAD"});
} }
} // namespace phi } // namespace phi
......
...@@ -18,10 +18,8 @@ namespace phi { ...@@ -18,10 +18,8 @@ namespace phi {
KernelSignature IndexSampleGradOpArgumentMapping( KernelSignature IndexSampleGradOpArgumentMapping(
const ArgumentMappingContext& ctx) { const ArgumentMappingContext& ctx) {
return KernelSignature("index_sample_grad", return KernelSignature(
{"X", "Index", GradVarName("Out")}, "index_sample_grad", {"X", "Index", "Out@GRAD"}, {}, {"X@GRAD"});
{},
{GradVarName("X")});
} }
} // namespace phi } // namespace phi
......
...@@ -18,10 +18,8 @@ namespace phi { ...@@ -18,10 +18,8 @@ namespace phi {
KernelSignature IndexSelectGradOpArgumentMapping( KernelSignature IndexSelectGradOpArgumentMapping(
const ArgumentMappingContext& ctx) { const ArgumentMappingContext& ctx) {
return KernelSignature("index_select_grad", return KernelSignature(
{"X", "Index", GradVarName("Out")}, "index_select_grad", {"X", "Index", "Out@GRAD"}, {"dim"}, {"X@GRAD"});
{"dim"},
{GradVarName("X")});
} }
} // namespace phi } // namespace phi
......
...@@ -92,81 +92,76 @@ KernelSignature BicubicInterpOpArgumentMapping( ...@@ -92,81 +92,76 @@ KernelSignature BicubicInterpOpArgumentMapping(
KernelSignature BilinearInterpGradOpArgumentMapping( KernelSignature BilinearInterpGradOpArgumentMapping(
const ArgumentMappingContext& ctx) { const ArgumentMappingContext& ctx) {
return KernelSignature( return KernelSignature("bilinear_interp_v2_grad",
"bilinear_interp_v2_grad", {"X", "OutSize", "SizeTensor", "Scale", "Out@GRAD"},
{"X", "OutSize", "SizeTensor", "Scale", GradVarName("Out")}, {"data_layout",
{"data_layout", "out_d",
"out_d", "out_h",
"out_h", "out_w",
"out_w", "scale",
"scale", "interp_method",
"interp_method", "align_corners",
"align_corners", "align_mode"},
"align_mode"}, {"X@GRAD"});
{GradVarName("X")});
} }
KernelSignature NearestInterpGradOpArgumentMapping( KernelSignature NearestInterpGradOpArgumentMapping(
const ArgumentMappingContext& ctx) { const ArgumentMappingContext& ctx) {
return KernelSignature( return KernelSignature("nearest_interp_v2_grad",
"nearest_interp_v2_grad", {"X", "OutSize", "SizeTensor", "Scale", "Out@GRAD"},
{"X", "OutSize", "SizeTensor", "Scale", GradVarName("Out")}, {"data_layout",
{"data_layout", "out_d",
"out_d", "out_h",
"out_h", "out_w",
"out_w", "scale",
"scale", "interp_method",
"interp_method", "align_corners",
"align_corners", "align_mode"},
"align_mode"}, {"X@GRAD"});
{GradVarName("X")});
} }
KernelSignature TrilinearInterpGradOpArgumentMapping( KernelSignature TrilinearInterpGradOpArgumentMapping(
const ArgumentMappingContext& ctx) { const ArgumentMappingContext& ctx) {
return KernelSignature( return KernelSignature("trilinear_interp_v2_grad",
"trilinear_interp_v2_grad", {"X", "OutSize", "SizeTensor", "Scale", "Out@GRAD"},
{"X", "OutSize", "SizeTensor", "Scale", GradVarName("Out")}, {"data_layout",
{"data_layout", "out_d",
"out_d", "out_h",
"out_h", "out_w",
"out_w", "scale",
"scale", "interp_method",
"interp_method", "align_corners",
"align_corners", "align_mode"},
"align_mode"}, {"X@GRAD"});
{GradVarName("X")});
} }
KernelSignature LinearInterpGradOpArgumentMapping( KernelSignature LinearInterpGradOpArgumentMapping(
const ArgumentMappingContext& ctx) { const ArgumentMappingContext& ctx) {
return KernelSignature( return KernelSignature("linear_interp_v2_grad",
"linear_interp_v2_grad", {"X", "OutSize", "SizeTensor", "Scale", "Out@GRAD"},
{"X", "OutSize", "SizeTensor", "Scale", GradVarName("Out")}, {"data_layout",
{"data_layout", "out_d",
"out_d", "out_h",
"out_h", "out_w",
"out_w", "scale",
"scale", "interp_method",
"interp_method", "align_corners",
"align_corners", "align_mode"},
"align_mode"}, {"X@GRAD"});
{GradVarName("X")});
} }
KernelSignature BicubicInterpGradOpArgumentMapping( KernelSignature BicubicInterpGradOpArgumentMapping(
const ArgumentMappingContext& ctx) { const ArgumentMappingContext& ctx) {
return KernelSignature( return KernelSignature("bicubic_interp_v2_grad",
"bicubic_interp_v2_grad", {"X", "OutSize", "SizeTensor", "Scale", "Out@GRAD"},
{"X", "OutSize", "SizeTensor", "Scale", GradVarName("Out")}, {"data_layout",
{"data_layout", "out_d",
"out_d", "out_h",
"out_h", "out_w",
"out_w", "scale",
"scale", "interp_method",
"interp_method", "align_corners",
"align_corners", "align_mode"},
"align_mode"}, {"X@GRAD"});
{GradVarName("X")});
} }
} // namespace phi } // namespace phi
......
...@@ -20,9 +20,9 @@ namespace phi { ...@@ -20,9 +20,9 @@ namespace phi {
KernelSignature KLDivLossGradOpArgumentMapping( KernelSignature KLDivLossGradOpArgumentMapping(
const ArgumentMappingContext& ctx) { const ArgumentMappingContext& ctx) {
return KernelSignature("kldiv_loss_grad", return KernelSignature("kldiv_loss_grad",
{"X", "Target", GradVarName("Loss")}, {"X", "Target", "Loss@GRAD"},
{"reduction"}, {"reduction"},
{GradVarName("X")}); {"X@GRAD"});
} }
} // namespace phi } // namespace phi
......
...@@ -17,10 +17,8 @@ ...@@ -17,10 +17,8 @@
namespace phi { namespace phi {
KernelSignature KronGradOpArgumentMapping(const ArgumentMappingContext& ctx) { KernelSignature KronGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("kron_grad", return KernelSignature(
{"X", "Y", GradVarName("Out")}, "kron_grad", {"X", "Y", "Out@GRAD"}, {}, {"X@GRAD", "Y@GRAD"});
{},
{GradVarName("X"), GradVarName("Y")});
} }
} // namespace phi } // namespace phi
......
...@@ -20,9 +20,9 @@ namespace phi { ...@@ -20,9 +20,9 @@ namespace phi {
KernelSignature KthvalueGradOpArgumentMapping( KernelSignature KthvalueGradOpArgumentMapping(
const ArgumentMappingContext& ctx) { const ArgumentMappingContext& ctx) {
return KernelSignature("kthvalue_grad", return KernelSignature("kthvalue_grad",
{"X", "Indices", GradVarName("Out")}, {"X", "Indices", "Out@GRAD"},
{"k", "axis", "keepdim"}, {"k", "axis", "keepdim"},
{GradVarName("X")}); {"X@GRAD"});
} }
} // namespace phi } // namespace phi
......
...@@ -24,10 +24,8 @@ KernelSignature LabelSmoothOpArgumentMapping( ...@@ -24,10 +24,8 @@ KernelSignature LabelSmoothOpArgumentMapping(
KernelSignature LabelSmoothGradOpArgumentMapping( KernelSignature LabelSmoothGradOpArgumentMapping(
const ArgumentMappingContext& ctx) { const ArgumentMappingContext& ctx) {
return KernelSignature("label_smooth_grad", return KernelSignature(
{GradVarName("Out")}, "label_smooth_grad", {"Out@GRAD"}, {"epsilon"}, {"X@GRAD"});
{"epsilon"},
{GradVarName("X")});
} }
} // namespace phi } // namespace phi
......
...@@ -25,11 +25,10 @@ KernelSignature LayerNormOpArgumentMapping(const ArgumentMappingContext& ctx) { ...@@ -25,11 +25,10 @@ KernelSignature LayerNormOpArgumentMapping(const ArgumentMappingContext& ctx) {
KernelSignature LayerNormGradOpArgumentMapping( KernelSignature LayerNormGradOpArgumentMapping(
const ArgumentMappingContext& ctx) { const ArgumentMappingContext& ctx) {
return KernelSignature( return KernelSignature("layer_norm_grad",
"layer_norm_grad", {"X", "Mean", "Variance", "Scale", "Bias", "Y@GRAD"},
{"X", "Mean", "Variance", "Scale", "Bias", GradVarName("Y")}, {"epsilon", "begin_norm_axis", "is_test"},
{"epsilon", "begin_norm_axis", "is_test"}, {"X@GRAD", "Scale@GRAD", "Bias@GRAD"});
{GradVarName("X"), GradVarName("Scale"), GradVarName("Bias")});
} }
} // namespace phi } // namespace phi
......
...@@ -22,9 +22,9 @@ KernelSignature LerpOpArgumentMapping(const ArgumentMappingContext& ctx) { ...@@ -22,9 +22,9 @@ KernelSignature LerpOpArgumentMapping(const ArgumentMappingContext& ctx) {
KernelSignature LerpGradOpArgumentMapping(const ArgumentMappingContext& ctx) { KernelSignature LerpGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("lerp_grad", return KernelSignature("lerp_grad",
{"X", "Y", "Weight", "Out", GradVarName("Out")}, {"X", "Y", "Weight", "Out", "Out@GRAD"},
{}, {},
{GradVarName("X"), GradVarName("Y")}); {"X@GRAD", "Y@GRAD"});
} }
} // namespace phi } // namespace phi
......
...@@ -17,8 +17,7 @@ ...@@ -17,8 +17,7 @@
namespace phi { namespace phi {
KernelSignature LgammaGradOpArgumentMapping(const ArgumentMappingContext& ctx) { KernelSignature LgammaGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature( return KernelSignature("lgamma_grad", {"X", "Out@GRAD"}, {}, {"X@GRAD"});
"lgamma_grad", {"X", GradVarName("Out")}, {}, {GradVarName("X")});
} }
} // namespace phi } // namespace phi
......
...@@ -19,9 +19,9 @@ namespace phi { ...@@ -19,9 +19,9 @@ namespace phi {
KernelSignature LogLossGradOpArgumentMapping( KernelSignature LogLossGradOpArgumentMapping(
const ArgumentMappingContext& ctx) { const ArgumentMappingContext& ctx) {
return KernelSignature("log_loss_grad", return KernelSignature("log_loss_grad",
{"Predicted", "Labels", GradVarName("Loss")}, {"Predicted", "Labels", "Loss@GRAD"},
{"epsilon"}, {"epsilon"},
{GradVarName("Predicted")}); {"Predicted@GRAD"});
} }
} // namespace phi } // namespace phi
......
(Diffs for the remaining 26 changed files are collapsed and not shown.)