未验证 提交 7ee31a96 编写于 作者: C Chen Weihang 提交者: GitHub

[Perf] Optimize dygraph scheduling performance (#41696)

* split phi and fluid infermeta context

* resolve conflict

* fix type error

* optimize scheduling perf

* spec small vector size

* replace all grad var name

* fix test failed

* move init defalut signature

* polish details

* polish details

* fix no init bug

* init sig for tests

* add init sig for infer

* fix infrt error

* fix infrt failed

* fix kunlun error

* fix infrt failed
上级 b5d9c31c
......@@ -308,10 +308,100 @@ void CompatMetaTensor::share_meta(const MetaTensor& meta_tensor) {
share_lod(meta_tensor);
}
phi::InferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
void CompatInferMetaContext::EmplaceBackInput(CompatMetaTensor input) {
int index = compat_inputs_.size();
compat_inputs_.emplace_back(std::move(input));
input_range_.emplace_back(std::pair<int, int>(index, index + 1));
}
void CompatInferMetaContext::EmplaceBackOutput(CompatMetaTensor output) {
int index = compat_outputs_.size();
compat_outputs_.emplace_back(std::move(output));
output_range_.emplace_back(std::pair<int, int>(index, index + 1));
}
void CompatInferMetaContext::EmplaceBackInputs(
paddle::SmallVector<CompatMetaTensor, phi::kInputSmallVectorSize> inputs) {
int index = compat_inputs_.size();
input_range_.emplace_back(std::pair<int, int>(index, index + inputs.size()));
compat_inputs_.insert(compat_inputs_.end(),
std::make_move_iterator(inputs.begin()),
std::make_move_iterator(inputs.end()));
}
void CompatInferMetaContext::EmplaceBackOutputs(
paddle::SmallVector<CompatMetaTensor, phi::kOutputSmallVectorSize>
outputs) {
int index = compat_outputs_.size();
output_range_.emplace_back(
std::pair<int, int>(index, index + outputs.size()));
compat_outputs_.insert(compat_outputs_.end(),
std::make_move_iterator(outputs.begin()),
std::make_move_iterator(outputs.end()));
}
const phi::MetaTensor& CompatInferMetaContext::InputAt(size_t idx) const {
return compat_inputs_.at(idx);
}
paddle::optional<const phi::MetaTensor&>
CompatInferMetaContext::OptionalInputAt(size_t idx) const {
const auto& input = compat_inputs_.at(idx);
return input.initialized()
? paddle::optional<const phi::MetaTensor&>{input}
: paddle::optional<const phi::MetaTensor&>{paddle::none};
}
std::vector<const phi::MetaTensor*> CompatInferMetaContext::InputsBetween(
size_t start, size_t end) const {
std::vector<const phi::MetaTensor*> result;
result.reserve(end - start);
for (size_t i = start; i < end; ++i) {
auto& in = compat_inputs_.at(i);
result.emplace_back(in.initialized() ? &in : nullptr);
}
return result;
}
paddle::optional<const std::vector<const phi::MetaTensor*>>
CompatInferMetaContext::OptionalInputsBetween(size_t start, size_t end) const {
const auto& first = compat_inputs_.at(start);
if (first.initialized()) {
std::vector<const phi::MetaTensor*> result;
result.reserve(end - start);
for (size_t i = start; i < end; ++i) {
auto& in = compat_inputs_.at(i);
result.emplace_back(in.initialized() ? &in : nullptr);
}
return paddle::optional<const std::vector<const phi::MetaTensor*>>(result);
}
return paddle::optional<const std::vector<const phi::MetaTensor*>>(
paddle::none);
}
phi::MetaTensor* CompatInferMetaContext::MutableOutputAt(size_t idx) {
auto& out = compat_outputs_.at(idx);
return out.initialized() ? &out : nullptr;
}
std::vector<phi::MetaTensor*> CompatInferMetaContext::MutableOutputBetween(
size_t start, size_t end) {
std::vector<phi::MetaTensor*> result;
result.reserve(end - start);
for (size_t i = start; i < end; ++i) {
auto& out = compat_outputs_.at(i);
result.emplace_back(out.initialized() ? &out : nullptr);
}
return result;
}
CompatInferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
const std::string& op_type) {
// 1. get kernel args
InitDefaultKernelSignatureMap();
auto arg_map_fn = phi::OpUtilsMap::Instance().GetArgumentMappingFn(op_type);
PADDLE_ENFORCE_NOT_NULL(
arg_map_fn, platform::errors::NotFound(
......@@ -321,52 +411,47 @@ phi::InferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
VLOG(3) << "BuildInferMetaContext: op kernel signature - " << signature;
// 2. build infermeta context
phi::InferMetaContext infer_meta_context(
CompatInferMetaContext infer_meta_context(
{ctx->IsRuntime(), ctx->IsRunMKLDNNKernel()});
auto& input_names = std::get<0>(signature.args);
auto& attr_names = std::get<1>(signature.args);
auto& output_names = std::get<2>(signature.args);
auto kernels_map =
phi::KernelFactory::Instance().SelectKernelMap(signature.name);
if (kernels_map.size() == 0) {
PADDLE_THROW(
platform::errors::Unimplemented("Not find `%s` kernels when construct "
"InferMetaContext.",
signature.name));
}
auto attr_defs = kernels_map.cbegin()->second.args_def().attribute_defs();
const auto& args_def =
phi::KernelFactory::Instance().GetFirstKernelArgsDef(signature.name);
const auto& attr_defs = args_def.attribute_defs();
// TODO(chenweihang): support multiple inputs and outputs later
phi::InferMetaContext infer_mete_context;
for (auto& in_name : input_names) {
if (ctx->HasInputs(in_name)) {
auto input_var = ctx->GetInputVarPtrs(in_name);
auto input_var = std::move(ctx->GetInputVarPtrs(in_name));
if (input_var.size() == 1) {
infer_meta_context.EmplaceBackInput(
std::make_shared<CompatMetaTensor>(input_var[0], ctx->IsRuntime()));
std::move(CompatMetaTensor(input_var[0], ctx->IsRuntime())));
} else {
paddle::SmallVector<std::shared_ptr<phi::MetaTensor>> inputs;
inputs.reserve(input_var.size());
paddle::SmallVector<CompatMetaTensor, phi::kInputSmallVectorSize>
inputs;
for (const auto& in : input_var) {
inputs.push_back(
std::make_shared<CompatMetaTensor>(in, ctx->IsRuntime()));
inputs.emplace_back(
std::move(CompatMetaTensor(in, ctx->IsRuntime())));
}
infer_meta_context.EmplaceBackInputs(std::move(inputs));
}
} else {
infer_meta_context.EmplaceBackInput({nullptr});
infer_meta_context.EmplaceBackInput(
std::move(CompatMetaTensor(ctx->IsRuntime())));
}
}
VLOG(6) << "BuildInferMetaContext: Done inputs";
auto attr_reader = ctx->Attrs();
for (size_t i = 0; i < attr_names.size(); ++i) {
auto attr_name = attr_names[i];
auto& attr_name = attr_names[i];
if (attr_defs[i].type_index == std::type_index(typeid(phi::IntArray))) {
// When attr is a vector_tensor or tensor, transform it to IntArray
if (ctx->HasInputs(attr_name) || ctx->HasInput(attr_name)) {
const auto& infershape_inputs = ctx->GetInputVarPtrs(attr_name);
auto infershape_inputs = std::move(ctx->GetInputVarPtrs(attr_name));
if (ctx->IsRuntime()) {
// If is in runtime, we will get tensor's value for IntArray
// and push it into attrs
......@@ -456,7 +541,7 @@ phi::InferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
attr_name));
}
} else if (ctx->HasInput(attr_name)) {
const auto& infershape_input = ctx->GetInputVarPtrs(attr_name);
auto infershape_input = std::move(ctx->GetInputVarPtrs(attr_name));
if (infershape_input.size() == 1) {
if (ctx->IsRuntime()) {
Variable* var = BOOST_GET_CONST(Variable*, infershape_input[0]);
......@@ -581,7 +666,7 @@ phi::InferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
// convert from data
if (attr_defs[i].type_index == std::type_index(typeid(int32_t))) {
if (ctx->IsRuntime()) {
const auto& infershape_inputs = ctx->GetInputVarPtrs(attr_name);
auto infershape_inputs = std::move(ctx->GetInputVarPtrs(attr_name));
auto var_temp = BOOST_GET_CONST(Variable*, infershape_inputs[i]);
auto val = experimental::MakePhiScalarFromVar(*var_temp);
int32_t val_int = val.template to<int32_t>();
......@@ -596,36 +681,41 @@ phi::InferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
}
}
VLOG(6) << "BuildInferMetaContext: Done attrs";
for (auto& out_name : output_names) {
if (ctx->HasOutputs(out_name, true)) {
auto output_var = ctx->GetOutputVarPtrs(out_name);
auto output_var = std::move(ctx->GetOutputVarPtrs(out_name));
if (output_var.size() == 1) {
infer_meta_context.EmplaceBackOutput(std::make_shared<CompatMetaTensor>(
output_var[0], ctx->IsRuntime()));
infer_meta_context.EmplaceBackOutput(
std::move(CompatMetaTensor(output_var[0], ctx->IsRuntime())));
} else {
paddle::SmallVector<std::shared_ptr<phi::MetaTensor>> outputs;
outputs.reserve(output_var.size());
paddle::SmallVector<CompatMetaTensor, phi::kOutputSmallVectorSize>
outputs;
for (const auto& out : output_var) {
if (ctx->IsRuntime()) {
if (BOOST_GET_CONST(Variable*, out)) {
outputs.emplace_back(
std::make_shared<CompatMetaTensor>(out, ctx->IsRuntime()));
std::move(CompatMetaTensor(out, ctx->IsRuntime())));
continue;
}
} else if (BOOST_GET_CONST(VarDesc*, out)) {
outputs.emplace_back(
std::make_shared<CompatMetaTensor>(out, ctx->IsRuntime()));
std::move(CompatMetaTensor(out, ctx->IsRuntime())));
continue;
}
outputs.emplace_back(nullptr);
outputs.emplace_back(std::move(CompatMetaTensor(ctx->IsRuntime())));
}
infer_meta_context.EmplaceBackOutputs(std::move(outputs));
}
} else {
infer_meta_context.EmplaceBackOutput({nullptr});
infer_meta_context.EmplaceBackOutput(
std::move(CompatMetaTensor(ctx->IsRuntime())));
}
}
VLOG(6) << "BuildInferMetaContext: Done outputs";
return infer_meta_context;
}
......
......@@ -18,38 +18,24 @@ limitations under the License. */
#include "paddle/fluid/framework/op_info.h"
#include "paddle/fluid/framework/shape_inference.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/core/meta_tensor.h"
namespace phi {
class InferMetaContext;
} // namespace phi
namespace paddle {
namespace framework {
phi::InferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
const std::string& op_type);
#define DECLARE_INFER_SHAPE_FUNCTOR(op_type, functor_name, fn) \
struct functor_name : public paddle::framework::InferShapeBase { \
void operator()( \
paddle::framework::InferShapeContext* ctx) const override { \
auto infer_meta_context = \
paddle::framework::BuildInferMetaContext(ctx, #op_type); \
fn(&infer_meta_context); \
} \
}
// TODO(chenweihang): Support TensorArray later
class CompatMetaTensor : public phi::MetaTensor {
public:
explicit CompatMetaTensor(bool is_runtime)
: is_runtime_(is_runtime), initialized_(false) {}
CompatMetaTensor(InferShapeVarPtr var, bool is_runtime)
: var_(std::move(var)), is_runtime_(is_runtime) {}
CompatMetaTensor() = default;
CompatMetaTensor(const CompatMetaTensor&) = default;
CompatMetaTensor(CompatMetaTensor&&) = default;
CompatMetaTensor& operator=(const CompatMetaTensor&) = delete;
CompatMetaTensor& operator=(CompatMetaTensor&&) = delete;
CompatMetaTensor& operator=(CompatMetaTensor&&) = default;
CompatMetaTensor(const CompatMetaTensor&) = default;
CompatMetaTensor& operator=(const CompatMetaTensor&) = default;
int64_t numel() const override;
......@@ -71,6 +57,8 @@ class CompatMetaTensor : public phi::MetaTensor {
void share_meta(const MetaTensor& meta_tensor) override;
bool initialized() const override { return initialized_; };
private:
const LoD& GetRuntimeLoD() const {
auto* var = BOOST_GET_CONST(Variable*, var_);
......@@ -95,7 +83,62 @@ class CompatMetaTensor : public phi::MetaTensor {
InferShapeVarPtr var_;
bool is_runtime_;
bool initialized_{true};
};
// Note: In order to avoid using shared_ptr to manage MetaTensor in
// InferMetaContext, inherit and implement InferMetaContext separately
// for compatibility with fluid, shared_ptr will cause significant decrease
// in scheduling performance
class CompatInferMetaContext : public phi::InferMetaContext {
public:
CompatInferMetaContext() = default;
explicit CompatInferMetaContext(phi::MetaConfig config)
: phi::InferMetaContext(config) {}
void EmplaceBackInput(CompatMetaTensor input);
void EmplaceBackOutput(CompatMetaTensor output);
void EmplaceBackInputs(
paddle::SmallVector<CompatMetaTensor, phi::kInputSmallVectorSize> inputs);
void EmplaceBackOutputs(
paddle::SmallVector<CompatMetaTensor, phi::kOutputSmallVectorSize>
outputs);
const phi::MetaTensor& InputAt(size_t idx) const override;
paddle::optional<const phi::MetaTensor&> OptionalInputAt(
size_t idx) const override;
std::vector<const phi::MetaTensor*> InputsBetween(size_t start,
size_t end) const override;
paddle::optional<const std::vector<const phi::MetaTensor*>>
OptionalInputsBetween(size_t start, size_t end) const override;
phi::MetaTensor* MutableOutputAt(size_t idx) override;
std::vector<phi::MetaTensor*> MutableOutputBetween(size_t start,
size_t end) override;
virtual ~CompatInferMetaContext() = default;
private:
paddle::SmallVector<CompatMetaTensor, phi::kInputSmallVectorSize>
compat_inputs_;
paddle::SmallVector<CompatMetaTensor, phi::kOutputSmallVectorSize>
compat_outputs_;
};
CompatInferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
const std::string& op_type);
#define DECLARE_INFER_SHAPE_FUNCTOR(op_type, functor_name, fn) \
struct functor_name : public paddle::framework::InferShapeBase { \
void operator()( \
paddle::framework::InferShapeContext* ctx) const override { \
auto infer_meta_context = \
paddle::framework::BuildInferMetaContext(ctx, #op_type); \
fn(&infer_meta_context); \
} \
}
} // namespace framework
} // namespace paddle
......@@ -328,20 +328,21 @@ bool InterpretercoreInferShapeContext::IsRunMKLDNNKernel() const {
}
// TODO(paddle-dev): Can this be template?
std::vector<InferShapeVarPtr> InterpretercoreInferShapeContext::GetInputVarPtrs(
paddle::SmallVector<InferShapeVarPtr, phi::kInputSmallVectorSize>
InterpretercoreInferShapeContext::GetInputVarPtrs(
const std::string& name) const {
const std::vector<Variable*>& vars = InputVars(name);
std::vector<InferShapeVarPtr> res;
paddle::SmallVector<InferShapeVarPtr, phi::kInputSmallVectorSize> res;
res.reserve(vars.size());
res.insert(res.begin(), vars.begin(), vars.end());
return res;
}
std::vector<InferShapeVarPtr>
paddle::SmallVector<InferShapeVarPtr, phi::kOutputSmallVectorSize>
InterpretercoreInferShapeContext::GetOutputVarPtrs(
const std::string& name) const {
const std::vector<Variable*>& vars = OutputVars(name);
std::vector<InferShapeVarPtr> res;
paddle::SmallVector<InferShapeVarPtr, phi::kOutputSmallVectorSize> res;
res.reserve(vars.size());
res.insert(res.begin(), vars.begin(), vars.end());
return res;
......
......@@ -90,11 +90,11 @@ class InterpretercoreInferShapeContext : public InferShapeContext {
bool IsRunMKLDNNKernel() const override;
// TODO(paddle-dev): Can this be template?
std::vector<InferShapeVarPtr> GetInputVarPtrs(
const std::string& name) const override;
paddle::SmallVector<InferShapeVarPtr, phi::kInputSmallVectorSize>
GetInputVarPtrs(const std::string& name) const override;
std::vector<InferShapeVarPtr> GetOutputVarPtrs(
const std::string& name) const override;
paddle::SmallVector<InferShapeVarPtr, phi::kOutputSmallVectorSize>
GetOutputVarPtrs(const std::string& name) const override;
DDim GetInputDim(const std::string& name) const override;
......
......@@ -202,10 +202,10 @@ class CompileTimeInferShapeContext : public InferShapeContext {
}
}
std::vector<InferShapeVarPtr> GetInputVarPtrs(
const std::string &name) const override {
paddle::SmallVector<InferShapeVarPtr, phi::kInputSmallVectorSize>
GetInputVarPtrs(const std::string &name) const override {
const std::vector<std::string> arg_names = Inputs(name);
std::vector<InferShapeVarPtr> res;
paddle::SmallVector<InferShapeVarPtr, phi::kInputSmallVectorSize> res;
res.reserve(arg_names.size());
std::transform(arg_names.begin(), arg_names.end(), std::back_inserter(res),
[this](const std::string &name) {
......@@ -214,10 +214,10 @@ class CompileTimeInferShapeContext : public InferShapeContext {
return res;
}
std::vector<InferShapeVarPtr> GetOutputVarPtrs(
const std::string &name) const override {
paddle::SmallVector<InferShapeVarPtr, phi::kOutputSmallVectorSize>
GetOutputVarPtrs(const std::string &name) const override {
const std::vector<std::string> arg_names = Outputs(name);
std::vector<InferShapeVarPtr> res;
paddle::SmallVector<InferShapeVarPtr, phi::kOutputSmallVectorSize> res;
res.reserve(arg_names.size());
std::transform(arg_names.begin(), arg_names.end(), std::back_inserter(res),
[this](const std::string &name) {
......
......@@ -945,19 +945,19 @@ class RuntimeInferShapeContext : public InferShapeContext {
}
// TODO(paddle-dev): Can this be template?
std::vector<InferShapeVarPtr> GetInputVarPtrs(
const std::string& name) const override {
paddle::SmallVector<InferShapeVarPtr, phi::kInputSmallVectorSize>
GetInputVarPtrs(const std::string& name) const override {
const std::vector<Variable*>& vars = InputVars(name);
std::vector<InferShapeVarPtr> res;
paddle::SmallVector<InferShapeVarPtr, phi::kInputSmallVectorSize> res;
res.reserve(vars.size());
res.insert(res.begin(), vars.begin(), vars.end());
return res;
}
std::vector<InferShapeVarPtr> GetOutputVarPtrs(
const std::string& name) const override {
paddle::SmallVector<InferShapeVarPtr, phi::kOutputSmallVectorSize>
GetOutputVarPtrs(const std::string& name) const override {
const std::vector<Variable*>& vars = OutputVars(name);
std::vector<InferShapeVarPtr> res;
paddle::SmallVector<InferShapeVarPtr, phi::kOutputSmallVectorSize> res;
res.reserve(vars.size());
res.insert(res.begin(), vars.begin(), vars.end());
return res;
......@@ -1324,8 +1324,8 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
<< ", using_kernel_key:" << *kernel_type_.get();
auto try_pt_kernel_key =
TransOpKernelTypeToPhiKernelKey(*kernel_type_.get());
if (!phi::KernelFactory::Instance().IsSelectKernelValid(
pt_kernel_name, try_pt_kernel_key)) {
if (!phi::KernelFactory::Instance().HasKernel(pt_kernel_name,
try_pt_kernel_key)) {
kernel_type_->library_type_ = expected_kernel_key_library_type;
VLOG(3) << "modify XPU KP kernel in static graph: " << type_
<< " is failed " << *kernel_type_.get();
......@@ -2113,10 +2113,12 @@ OpKernelType OperatorWithKernel::GetKernelTypeForVar(
KernelSignature OperatorWithKernel::GetExpectedPhiKernelArgs(
const ExecutionContext& ctx) const {
InitDefaultKernelSignatureMap();
ExecutionArgumentMappingContext arg_mapping_ctx(ctx);
return phi::OpUtilsMap::Instance().GetArgumentMappingFn(Type())(
arg_mapping_ctx);
if (arg_map_fn_ == nullptr) {
arg_map_fn_.reset(new phi::ArgumentMappingFn(
phi::OpUtilsMap::Instance().GetArgumentMappingFn(Type())));
}
return (*arg_map_fn_)(arg_mapping_ctx);
}
Scope* OperatorWithKernel::PreparePhiData(
......
......@@ -701,6 +701,7 @@ class OperatorWithKernel : public OperatorBase {
mutable bool run_kp_kernel = false;
mutable std::unique_ptr<phi::KernelSignature> pt_kernel_signature_;
mutable std::unique_ptr<phi::Kernel> pt_kernel_;
mutable std::unique_ptr<phi::ArgumentMappingFn> arg_map_fn_;
};
extern bool OpSupportGPU(const std::string& op_type);
......
......@@ -25,6 +25,7 @@ limitations under the License. */
#include "paddle/phi/core/compat/convert_utils.h"
#include "paddle/phi/core/compat/op_utils.h"
#include "paddle/phi/core/kernel_factory.h"
#include "paddle/phi/core/type_defs.h"
namespace paddle {
namespace framework {
......@@ -40,9 +41,9 @@ class KernelArgsNameMakerByOpProto : public KernelArgsNameMaker {
~KernelArgsNameMakerByOpProto() {}
const paddle::SmallVector<std::string>& GetInputArgsNames() override;
const paddle::SmallVector<std::string>& GetOutputArgsNames() override;
const paddle::SmallVector<std::string>& GetAttrsArgsNames() override;
const paddle::SmallVector<const char*>& GetInputArgsNames() override;
const paddle::SmallVector<const char*>& GetOutputArgsNames() override;
const paddle::SmallVector<const char*>& GetAttrsArgsNames() override;
KernelSignature GetKernelSignature();
......@@ -52,9 +53,9 @@ class KernelArgsNameMakerByOpProto : public KernelArgsNameMaker {
private:
const framework::proto::OpProto* op_proto_;
paddle::SmallVector<std::string> input_names_;
paddle::SmallVector<std::string> output_names_;
paddle::SmallVector<std::string> attr_names_;
paddle::SmallVector<const char*> input_names_;
paddle::SmallVector<const char*> output_names_;
paddle::SmallVector<const char*> attr_names_;
};
OpKernelType TransPhiKernelKeyToOpKernelType(const phi::KernelKey& kernel_key) {
......@@ -102,7 +103,7 @@ phi::KernelKey FallBackToCpu(const OpKernelType& expected_kernel_key,
if (platform::is_xpu_place(expected_kernel_key.place_) ||
paddle::platform::is_in_xpu_black_list(op.Type())) {
VLOG(3) << "phi missing XPU kernel: " << op.Type()
<< ", phipected_kernel_key:" << expected_kernel_key
<< ", expected_kernel_key:" << expected_kernel_key
<< ", fallbacking to CPU one!";
return phi::KernelKey(phi::Backend::CPU, kernel_key.layout(),
kernel_key.dtype());
......@@ -111,7 +112,7 @@ phi::KernelKey FallBackToCpu(const OpKernelType& expected_kernel_key,
#ifdef PADDLE_WITH_ASCEND_CL
if (platform::is_npu_place(expected_kernel_key.place_)) {
VLOG(3) << "phi missing NPU kernel: " << op.Type()
<< ", phipected_kernel_key:" << expected_kernel_key
<< ", expected_kernel_key:" << expected_kernel_key
<< ", fallbacking to CPU one!";
return phi::KernelKey(phi::Backend::CPU, kernel_key.layout(),
kernel_key.dtype());
......@@ -120,7 +121,7 @@ phi::KernelKey FallBackToCpu(const OpKernelType& expected_kernel_key,
#ifdef PADDLE_WITH_MLU
if (platform::is_mlu_place(expected_kernel_key.place_)) {
VLOG(3) << "phi missing MLU kernel: " << op.Type()
<< ", phipected_kernel_key:" << expected_kernel_key
<< ", expected_kernel_key:" << expected_kernel_key
<< ", fallbacking to CPU one!";
return phi::KernelKey(phi::Backend::CPU, kernel_key.layout(),
kernel_key.dtype());
......@@ -129,7 +130,7 @@ phi::KernelKey FallBackToCpu(const OpKernelType& expected_kernel_key,
#ifdef PADDLE_WITH_IPU
if (platform::is_ipu_place(expected_kernel_key.place_)) {
VLOG(3) << "phi missing IPU kernel: " << op.Type()
<< ", phipected_kernel_key:" << expected_kernel_key
<< ", expected_kernel_key:" << expected_kernel_key
<< ", fallbacking to CPU one!";
return phi::KernelKey(phi::Backend::CPU, kernel_key.layout(),
kernel_key.dtype());
......@@ -139,7 +140,7 @@ phi::KernelKey FallBackToCpu(const OpKernelType& expected_kernel_key,
if (platform::is_custom_place(expected_kernel_key.place_)) {
VLOG(3) << "phi missing " << expected_kernel_key.place_.GetDeviceType()
<< " kernel: " << op.Type()
<< ", phipected_kernel_key:" << expected_kernel_key
<< ", expected_kernel_key:" << expected_kernel_key
<< ", fallbacking to CPU one!";
return phi::KernelKey(phi::Backend::CPU, kernel_key.layout(),
kernel_key.dtype());
......@@ -148,45 +149,52 @@ phi::KernelKey FallBackToCpu(const OpKernelType& expected_kernel_key,
return phi::KernelKey();
}
const paddle::SmallVector<std::string>&
const paddle::SmallVector<const char*>&
KernelArgsNameMakerByOpProto::GetInputArgsNames() {
for (int i = 0; i < op_proto_->inputs_size(); ++i) {
auto& in = op_proto_->inputs()[i];
auto& in_name = in.name();
if ((in.has_extra() && in.extra()) || (in.has_quant() && in.quant())) {
VLOG(6) << "Parse PhiKernel input: skip extra & quant input - "
<< in_name;
continue;
}
// If contains dispensable input, we should override the
// OpArgumentMapping method self in phi/ops/compat dir
if (in.has_dispensable() && in.dispensable()) {
VLOG(6) << "Parse PhiKernel input: skip dispensable input - " << in_name;
continue;
}
VLOG(6) << "Parse PhiKernel input: " << in_name;
input_names_.emplace_back(in_name);
input_names_.emplace_back(in_name.c_str());
}
if (VLOG_IS_ON(10)) {
std::ostringstream sout;
sout << "PhiKernel inputs: ";
std::copy(input_names_.begin(), input_names_.end(),
std::ostream_iterator<const char*>(sout, ", "));
VLOG(10) << sout.str();
}
return input_names_;
}
const paddle::SmallVector<std::string>&
const paddle::SmallVector<const char*>&
KernelArgsNameMakerByOpProto::GetOutputArgsNames() {
for (int i = 0; i < op_proto_->outputs_size(); ++i) {
auto& out = op_proto_->outputs()[i];
auto& out_name = out.name();
if ((out.has_extra() && out.extra()) || (out.has_quant() && out.quant())) {
VLOG(6) << "Parse PhiKernel output: skip extra & quant output - "
<< out_name;
continue;
}
VLOG(6) << "Parse PhiKernel output: " << out_name;
output_names_.emplace_back(out_name);
output_names_.emplace_back(out_name.c_str());
}
if (VLOG_IS_ON(10)) {
std::ostringstream sout;
sout << "PhiKernel outputs: ";
std::copy(output_names_.begin(), output_names_.end(),
std::ostream_iterator<const char*>(sout, ", "));
VLOG(10) << sout.str();
}
return output_names_;
}
const paddle::SmallVector<std::string>&
const paddle::SmallVector<const char*>&
KernelArgsNameMakerByOpProto::GetAttrsArgsNames() {
for (int i = 0; i < op_proto_->attrs_size(); ++i) {
auto& attr = op_proto_->attrs()[i];
......@@ -195,25 +203,26 @@ KernelArgsNameMakerByOpProto::GetAttrsArgsNames() {
attr_name == "op_role" || attr_name == "op_role_var" ||
attr_name == "op_namescope" || attr_name == "op_callstack" ||
attr_name == "op_device") {
VLOG(6) << "Parse PhiKernel attribute: skip needless attr - "
<< attr_name;
continue;
}
if ((attr.has_extra() && attr.extra()) ||
(attr.has_quant() && attr.quant())) {
VLOG(6) << "Parse PhiKernel attribute: skip extra & quant attr - "
<< attr_name;
continue;
}
VLOG(6) << "Parse PhiKernel attribute: " << attr_name;
attr_names_.emplace_back(attr_name);
attr_names_.emplace_back(attr_name.c_str());
}
if (VLOG_IS_ON(10)) {
std::ostringstream sout;
sout << "PhiKernel attributes: ";
std::copy(attr_names_.begin(), attr_names_.end(),
std::ostream_iterator<const char*>(sout, ", "));
VLOG(10) << sout.str();
}
return attr_names_;
}
KernelSignature KernelArgsNameMakerByOpProto::GetKernelSignature() {
return KernelSignature(phi::TransToPhiKernelName(op_proto_->type()),
return KernelSignature(phi::TransToPhiKernelName(op_proto_->type()).c_str(),
GetInputArgsNames(), GetAttrsArgsNames(),
GetOutputArgsNames());
}
......@@ -228,7 +237,7 @@ void InitDefaultKernelSignatureMap() {
if (phi::KernelFactory::Instance().HasCompatiblePhiKernel(op_type) &&
op_proto) {
paddle::framework::KernelArgsNameMakerByOpProto maker(op_proto);
VLOG(10) << "Register kernel signature for " << op_type;
VLOG(10) << "Register `" << op_type << "` kernel signature:";
phi::DefaultKernelSignatureMap::Instance().Insert(
op_type, std::move(maker.GetKernelSignature()));
}
......
......@@ -55,9 +55,9 @@ phi::KernelKey FallBackToCpu(const OpKernelType& expected_kernel_key,
class KernelArgsNameMaker {
public:
virtual ~KernelArgsNameMaker() {}
virtual const paddle::SmallVector<std::string>& GetInputArgsNames() = 0;
virtual const paddle::SmallVector<std::string>& GetOutputArgsNames() = 0;
virtual const paddle::SmallVector<std::string>& GetAttrsArgsNames() = 0;
virtual const paddle::SmallVector<const char*>& GetInputArgsNames() = 0;
virtual const paddle::SmallVector<const char*>& GetOutputArgsNames() = 0;
virtual const paddle::SmallVector<const char*>& GetAttrsArgsNames() = 0;
};
void InitDefaultKernelSignatureMap();
......
......@@ -21,6 +21,8 @@ limitations under the License. */
#include "paddle/fluid/framework/var_desc.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/phi/core/ddim.h"
#include "paddle/phi/core/type_defs.h"
#include "paddle/utils/small_vector.h"
namespace paddle {
namespace framework {
......@@ -106,10 +108,10 @@ class InferShapeContext {
virtual bool IsRunMKLDNNKernel() const = 0;
virtual std::vector<InferShapeVarPtr> GetInputVarPtrs(
const std::string &name) const = 0;
virtual std::vector<InferShapeVarPtr> GetOutputVarPtrs(
const std::string &name) const = 0;
virtual paddle::SmallVector<InferShapeVarPtr, phi::kInputSmallVectorSize>
GetInputVarPtrs(const std::string &name) const = 0;
virtual paddle::SmallVector<InferShapeVarPtr, phi::kOutputSmallVectorSize>
GetOutputVarPtrs(const std::string &name) const = 0;
protected:
virtual std::vector<DDim> GetRepeatedDims(const std::string &name) const = 0;
......
......@@ -235,9 +235,10 @@ class DygraphInferShapeContext : public framework::InferShapeContext {
(op_kernel_type_->data_layout_ == framework::DataLayout::kMKLDNN));
}
std::vector<framework::InferShapeVarPtr> GetInputVarPtrs(
const std::string& name) const override {
std::vector<framework::InferShapeVarPtr> res;
paddle::SmallVector<framework::InferShapeVarPtr, phi::kInputSmallVectorSize>
GetInputVarPtrs(const std::string& name) const override {
paddle::SmallVector<framework::InferShapeVarPtr, phi::kInputSmallVectorSize>
res;
auto it = var_map_in_->find(name);
PADDLE_ENFORCE_NE(
it, var_map_in_->end(),
......@@ -248,9 +249,11 @@ class DygraphInferShapeContext : public framework::InferShapeContext {
return res;
}
std::vector<framework::InferShapeVarPtr> GetOutputVarPtrs(
const std::string& name) const override {
std::vector<framework::InferShapeVarPtr> res;
paddle::SmallVector<framework::InferShapeVarPtr, phi::kOutputSmallVectorSize>
GetOutputVarPtrs(const std::string& name) const override {
paddle::SmallVector<framework::InferShapeVarPtr,
phi::kOutputSmallVectorSize>
res;
auto it = var_map_out_->find(name);
PADDLE_ENFORCE_NE(
it, var_map_out_->end(),
......
......@@ -36,6 +36,8 @@ DECLARE_bool(run_kp_kernel);
namespace paddle {
namespace imperative {
static const phi::Kernel empty_kernel;
const std::shared_ptr<VariableWrapper>& GetVariableWrapper(
const std::shared_ptr<paddle::imperative::VarBase>& var) {
return var->SharedVar();
......@@ -108,12 +110,13 @@ PreparedOp::PreparedOp(const framework::OperatorBase& op,
ctx_(ctx),
kernel_type_(kernel_type),
func_(func),
dev_ctx_(dev_ctx) {}
dev_ctx_(dev_ctx),
pt_kernel_(empty_kernel) {}
PreparedOp::PreparedOp(const framework::OperatorBase& op,
const framework::RuntimeContext& ctx,
const framework::OpKernelType& kernel_type,
const framework::KernelSignature& kernel_signature,
framework::KernelSignature&& kernel_signature,
const phi::Kernel& pt_kernel,
platform::DeviceContext* dev_ctx)
: op_(op),
......@@ -122,7 +125,7 @@ PreparedOp::PreparedOp(const framework::OperatorBase& op,
func_(nullptr),
dev_ctx_(dev_ctx),
run_phi_kernel_(true),
pt_kernel_signature_(kernel_signature),
pt_kernel_signature_(std::move(kernel_signature)),
pt_kernel_(pt_kernel) {}
template <typename VarType>
......@@ -170,7 +173,8 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins,
#endif
if (phi::KernelFactory::Instance().HasCompatiblePhiKernel(op.Type())) {
pt_kernel_signature = op.GetExpectedPhiKernelArgs(dygraph_exe_ctx);
pt_kernel_signature =
std::move(op.GetExpectedPhiKernelArgs(dygraph_exe_ctx));
VLOG(6) << pt_kernel_signature;
pt_kernel_name = pt_kernel_signature.name;
......@@ -200,8 +204,8 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins,
<< ", using_kernel_key:" << expected_kernel_key;
phi::KernelKey try_pt_kernel_key =
TransOpKernelTypeToPhiKernelKey(expected_kernel_key);
if (!phi::KernelFactory::Instance().IsSelectKernelValid(
pt_kernel_name, try_pt_kernel_key)) {
if (!phi::KernelFactory::Instance().HasKernel(pt_kernel_name,
try_pt_kernel_key)) {
expected_kernel_key.library_type_ = expected_kernel_key_library_type;
VLOG(3) << "modify XPU KP kernel: " << op.Type() << " is failed "
<< expected_kernel_key;
......@@ -211,8 +215,8 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins,
#endif
pt_kernel_key = TransOpKernelTypeToPhiKernelKey(expected_kernel_key);
auto pt_kernel = phi::KernelFactory::Instance().SelectKernel(pt_kernel_name,
pt_kernel_key);
auto& pt_kernel = phi::KernelFactory::Instance().SelectKernel(
pt_kernel_name, pt_kernel_key);
if (pt_kernel.IsValid()
#if defined(PADDLE_WITH_XPU) && !defined(PADDLE_WITH_XPU_KP)
......@@ -227,9 +231,8 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins,
dev_ctx = pool.Get(expected_kernel_key.place_);
}
// TODO(chenweihang): using CPUKernel when miss device kernel case
return PreparedOp(op, ctx, expected_kernel_key, pt_kernel_signature,
pt_kernel, dev_ctx);
return PreparedOp(op, ctx, expected_kernel_key,
std::move(pt_kernel_signature), pt_kernel, dev_ctx);
} else {
VLOG(6) << "Dynamic mode ChoosePhiKernel - kernel `" << pt_kernel_name
<< "` not found.";
......@@ -270,15 +273,16 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins,
if (phi::KernelFactory::Instance().HasCompatiblePhiKernel(op.Type())) {
auto pt_cpu_kernel_key =
FallBackToCpu(expected_kernel_key, pt_kernel_key, op);
auto pt_cpu_kernel = phi::KernelFactory::Instance().SelectKernel(
auto& pt_cpu_kernel = phi::KernelFactory::Instance().SelectKernel(
pt_kernel_name, pt_cpu_kernel_key);
if (pt_cpu_kernel.IsValid()) {
VLOG(6) << "Dynamic mode PrepareImpl - kernel name: " << pt_kernel_name
<< " | kernel key: " << pt_cpu_kernel_key
<< " | kernel: " << pt_cpu_kernel;
auto* cpu_ctx = pool.Get(paddle::platform::CPUPlace());
return PreparedOp(op, ctx, expected_kernel_key, pt_kernel_signature,
pt_cpu_kernel, cpu_ctx);
return PreparedOp(op, ctx, expected_kernel_key,
std::move(pt_kernel_signature), pt_cpu_kernel,
cpu_ctx);
}
}
}
......@@ -505,7 +509,6 @@ static void PreparedOpRunPtImpl(
#endif
}
// TODO(chenweihang): add debug flags later
if (framework::IsComplexType(kernel_type.data_type_)) {
HandleComplexGradToRealGrad<VarType>(outs);
}
......
......@@ -154,7 +154,7 @@ class PreparedOp {
PreparedOp(const framework::OperatorBase& op,
const framework::RuntimeContext& ctx,
const framework::OpKernelType& kernel_type,
const framework::KernelSignature& kernel_signature,
framework::KernelSignature&& kernel_signature,
const phi::Kernel& pt_kernel, platform::DeviceContext* dev_ctx);
static PreparedOp Prepare(const NameVarMap<VarBase>& ins,
......@@ -206,7 +206,7 @@ class PreparedOp {
bool run_phi_kernel_{false};
bool run_kp_kernel_{false};
framework::KernelSignature pt_kernel_signature_;
phi::Kernel pt_kernel_;
const phi::Kernel& pt_kernel_;
};
const inline framework::Attribute& GetAttr(
......@@ -289,7 +289,7 @@ void BuildDygraphPhiKernelContext(
}
}
auto ins_vector = it->second;
auto& ins_vector = it->second;
size_t end_idx = start_idx + ins_vector.size();
for (size_t offset = 0; offset < ins_vector.size(); ++offset) {
......@@ -587,7 +587,7 @@ void PreparePhiData(const phi::Kernel& pt_kernel,
auto& ins_vector = ins.at(input_names[i]);
for (size_t offset = 0; offset < ins_vector.size(); ++offset) {
auto var = ins_vector[offset];
auto& var = ins_vector[offset];
const auto* tensor_in = GetTensorFromVar(var->Var());
if (tensor_in && tensor_in->IsInitialized()) {
if (in_def.backend == phi::Backend::ALL_BACKEND) {
......
......@@ -226,6 +226,7 @@ bool AnalysisPredictor::PrepareScope(
status_is_cloned_ = true;
} else {
paddle::framework::InitDevices();
paddle::framework::InitDefaultKernelSignatureMap();
// TODO(wilber): we need to release memory occupied by weights.
scope_.reset(new paddle::framework::Scope());
status_is_cloned_ = false;
......
......@@ -92,6 +92,7 @@ bool NativePaddlePredictor::Init(
"The sub_scope should not be nullptr."));
} else {
paddle::framework::InitDevices();
paddle::framework::InitDefaultKernelSignatureMap();
scope_.reset(new paddle::framework::Scope());
}
......
......@@ -517,10 +517,8 @@ class WhileGradOpShapeInference : public framework::InferShapeBase {
ctx->HasInputs(kOutputs);
ctx->HasInputs(framework::GradVarName(kOutputs));
auto pg_ig_names = ctx->Outputs(kXGRAD);
std::vector<framework::InferShapeVarPtr> in_var_ptrs =
ctx->GetInputVarPtrs(kX);
std::vector<framework::InferShapeVarPtr> out_var_ptrs =
ctx->GetOutputVarPtrs(kXGRAD);
auto in_var_ptrs = ctx->GetInputVarPtrs(kX);
auto out_var_ptrs = ctx->GetOutputVarPtrs(kXGRAD);
PADDLE_ENFORCE_EQ(in_var_ptrs.size(), out_var_ptrs.size(),
platform::errors::InvalidArgument(
"The size of Inputs(X) must be the same as "
......
......@@ -63,10 +63,8 @@ class CollectFpnProposalsOp : public framework::OperatorWithKernel {
context->ShareLoD("MultiLevelRois", "FpnRois");
}
if (context->IsRuntime() && !context->HasInputs("MultiLevelRoIsNum")) {
std::vector<framework::InferShapeVarPtr> roi_inputs =
context->GetInputVarPtrs("MultiLevelRois");
std::vector<framework::InferShapeVarPtr> score_inputs =
context->GetInputVarPtrs("MultiLevelScores");
auto roi_inputs = context->GetInputVarPtrs("MultiLevelRois");
auto score_inputs = context->GetInputVarPtrs("MultiLevelScores");
for (size_t i = 0; i < roi_inputs.size(); ++i) {
framework::Variable *roi_var =
BOOST_GET(framework::Variable *, roi_inputs[i]);
......
......@@ -60,6 +60,7 @@ limitations under the License. */
#include "paddle/fluid/pybind/uva_utils.h"
#include "paddle/phi/core/compat/arg_map_context.h"
#include "paddle/phi/core/compat/type_defs.h"
#include "paddle/phi/core/type_defs.h"
namespace paddle {
namespace pybind {
......@@ -2027,7 +2028,8 @@ void BindImperative(py::module *m_ptr) {
*(imperative::AmpOperators::Instance().GetMutableAllowOps()),
*(imperative::AmpOperators::Instance().GetMutableBlockOps()));
})
.def("_get_kernel_signature",
.def(
"_get_kernel_signature",
[](imperative::Tracer &self, const std::string &type,
const PyNameVarBaseMap &ins, const PyNameVarBaseMap &outs,
framework::AttributeMap attrs) {
......@@ -2035,14 +2037,22 @@ void BindImperative(py::module *m_ptr) {
auto ins_map = ConvertToNameTensorMap(ins);
auto outs_map = ConvertToNameTensorMap(outs);
{
auto to_vector = [](paddle::SmallVector<std::string> &vec) {
auto input_to_vector =
[](paddle::SmallVector<const char *> &vec) {
return std::vector<std::string>(vec.begin(), vec.end());
};
auto output_to_vector =
[](paddle::SmallVector<const char *> &vec) {
return std::vector<std::string>(vec.begin(), vec.end());
};
auto attr_to_vector = [](paddle::SmallVector<const char *> &vec) {
return std::vector<std::string>(vec.begin(), vec.end());
};
auto ret = self.GetExpectedKernelSignature(type, ins_map,
outs_map, attrs);
auto kernelsig_ins = to_vector(std::get<0>(ret.args));
auto kernelsig_attrs = to_vector(std::get<1>(ret.args));
auto kernelsig_outs = to_vector(std::get<2>(ret.args));
auto kernelsig_ins = input_to_vector(std::get<0>(ret.args));
auto kernelsig_attrs = attr_to_vector(std::get<1>(ret.args));
auto kernelsig_outs = output_to_vector(std::get<2>(ret.args));
return std::make_tuple(kernelsig_ins, kernelsig_attrs,
kernelsig_outs);
}
......
......@@ -2941,6 +2941,8 @@ All parameter, weight, gradient are variables in Paddle.
framework::LoadOpMetaInfoAndRegisterOp(dso_name));
});
m.def("init_devices", []() { framework::InitDevices(); });
m.def("init_default_kernel_signatures",
[]() { framework::InitDefaultKernelSignatureMap(); });
m.def("is_compiled_with_cuda", IsCompiledWithCUDA);
m.def("is_compiled_with_ascend", IsCompiledWithAscend);
m.def("is_compiled_with_rocm", IsCompiledWithROCM);
......
......@@ -15,6 +15,7 @@
#include "paddle/infrt/dialect/phi/pass/kernel_op_desc.h"
#include <glog/logging.h>
#include "paddle/infrt/dialect/phi/data_type.h"
#include "paddle/phi/core/type_defs.h"
#include "paddle/phi/kernels/declarations.h"
namespace infrt {
......@@ -92,10 +93,10 @@ std::vector<PhiKernelDesc> GetCandidateKernels(
phi_kernel_desc.input_types.clear();
phi_kernel_desc.output_types.clear();
phi::KernelArgsDef args_def = kernel_key_map.at(kernel_key).args_def();
const paddle::SmallVector<phi::TensorArgDef>& input_arg =
args_def.input_defs();
const paddle::SmallVector<phi::TensorArgDef>& output_arg =
args_def.output_defs();
const paddle::SmallVector<phi::TensorArgDef, phi::kInputSmallVectorSize>&
input_arg = args_def.input_defs();
const paddle::SmallVector<phi::TensorArgDef, phi::kOutputSmallVectorSize>&
output_arg = args_def.output_defs();
for (auto tensor_arg : input_arg) {
phi_kernel_desc.input_types.emplace_back(ConvertPlaceFromPhi(tensor_arg));
}
......
......@@ -91,6 +91,7 @@ using ValueVariantType =
std::vector<::phi::DenseTensor*>,
paddle::experimental::ScalarBase<::phi::DenseTensor>,
paddle::experimental::IntArrayBase<::phi::DenseTensor>,
std::vector<const ::phi::MetaTensor*>,
std::vector<::phi::MetaTensor*>,
::phi::MetaConfig,
paddle::experimental::Backend,
......
......@@ -19,45 +19,33 @@ limitations under the License. */
#include <tuple>
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/type_defs.h"
#include "paddle/utils/any.h"
#include "paddle/utils/flat_hash_map.h"
#include "paddle/utils/small_vector.h"
namespace phi {
constexpr char kGradVarSuffix[] = "@GRAD";
constexpr size_t kGradVarSuffixSize = 5U;
inline std::string GradVarName(const std::string& var_name) {
std::string result;
result.reserve(var_name.size() + kGradVarSuffixSize);
result += var_name;
result += kGradVarSuffix;
return result;
}
// tuple(input_names, attr_names, output_names)
using KernelArgsTuple = std::tuple<paddle::SmallVector<std::string>,
paddle::SmallVector<std::string>,
paddle::SmallVector<std::string>>;
using KernelArgsTuple = std::tuple<paddle::SmallVector<const char*>,
paddle::SmallVector<const char*>,
paddle::SmallVector<const char*>>;
struct KernelSignature {
std::string name;
const char* name;
KernelArgsTuple args;
KernelSignature() = default;
KernelSignature(std::string&& kernel_name,
paddle::SmallVector<std::string>&& inputs,
paddle::SmallVector<std::string>&& attrs,
paddle::SmallVector<std::string>&& outputs)
: name(std::move(kernel_name)),
args(std::make_tuple(inputs, attrs, outputs)) {}
KernelSignature(const std::string& kernel_name,
const paddle::SmallVector<std::string>& inputs,
const paddle::SmallVector<std::string>& attrs,
const paddle::SmallVector<std::string>& outputs)
KernelSignature(const char* kernel_name,
paddle::SmallVector<const char*>&& inputs,
paddle::SmallVector<const char*>&& attrs,
paddle::SmallVector<const char*>&& outputs)
: name(kernel_name), args(std::make_tuple(inputs, attrs, outputs)) {}
KernelSignature(const char* kernel_name,
const paddle::SmallVector<const char*>& inputs,
const paddle::SmallVector<const char*>& attrs,
const paddle::SmallVector<const char*>& outputs)
: name(kernel_name), args(std::make_tuple(inputs, attrs, outputs)) {}
// TODO(chenweihang): add assign constructor to solve windows compile
......
......@@ -102,7 +102,7 @@ phi::Place TransToPhiPlace(const Backend& backend, bool set_device_id) {
}
}
std::string TransToPhiKernelName(const std::string& fluid_op_name) {
const std::string& TransToPhiKernelName(const std::string& fluid_op_name) {
return OpUtilsMap::Instance().GetBaseKernelName(fluid_op_name);
}
......
......@@ -22,7 +22,7 @@ limitations under the License. */
namespace phi {
std::string TransToPhiKernelName(const std::string& fluid_op_name);
const std::string& TransToPhiKernelName(const std::string& fluid_op_name);
const std::string& TransToFluidOpName(const std::string& phi_kernel_name);
Backend TransToPhiBackend(const phi::Place& place);
......
......@@ -26,6 +26,8 @@ limitations under the License. */
namespace phi {
const static std::string deprecated_kernel_name = "deprecated"; // NOLINT
const std::unordered_set<std::string> standard_kernel_suffixs({
"sr", // SelectedRows kernel
"raw" // fallback kernel of origfinal fluid op
......@@ -134,9 +136,9 @@ class OpUtilsMap {
arg_mapping_fn_map_.insert({std::move(op_type), std::move(fn)});
}
std::string GetBaseKernelName(const std::string& op_type) const {
const std::string& GetBaseKernelName(const std::string& op_type) const {
if (deprecated_op_names.find(op_type) != deprecated_op_names.end()) {
return "deprecated";
return deprecated_kernel_name;
}
auto it = base_kernel_name_map_.find(op_type);
if (it == base_kernel_name_map_.end()) {
......@@ -150,7 +152,7 @@ class OpUtilsMap {
auto it = arg_mapping_fn_map_.find(op_type);
if (it == arg_mapping_fn_map_.end()) {
auto func =
[op_type](const ArgumentMappingContext& ctx) -> KernelSignature {
[&op_type](const ArgumentMappingContext& ctx) -> KernelSignature {
return DefaultKernelSignatureMap::Instance().Get(op_type);
};
return func;
......
......@@ -20,14 +20,12 @@ void InferMetaContext::SetMetaConfig(MetaConfig config) {
config_ = std::move(config);
}
void InferMetaContext::EmplaceBackInput(
std::shared_ptr<phi::MetaTensor> input) {
void InferMetaContext::EmplaceBackInput(MetaTensor input) {
int index = inputs_.size();
inputs_.emplace_back(std::move(input));
input_range_.emplace_back(std::pair<int, int>(index, index + 1));
}
void InferMetaContext::EmplaceBackOutput(
std::shared_ptr<phi::MetaTensor> output) {
void InferMetaContext::EmplaceBackOutput(MetaTensor output) {
int index = outputs_.size();
outputs_.emplace_back(std::move(output));
output_range_.emplace_back(std::pair<int, int>(index, index + 1));
......@@ -37,7 +35,7 @@ void InferMetaContext::EmplaceBackAttr(paddle::any attr) {
}
void InferMetaContext::EmplaceBackInputs(
paddle::SmallVector<std::shared_ptr<phi::MetaTensor>> inputs) {
paddle::SmallVector<MetaTensor, phi::kInputSmallVectorSize> inputs) {
int index = inputs_.size();
input_range_.emplace_back(std::pair<int, int>(index, index + inputs.size()));
inputs_.insert(inputs_.end(),
......@@ -45,7 +43,7 @@ void InferMetaContext::EmplaceBackInputs(
std::make_move_iterator(inputs.end()));
}
void InferMetaContext::EmplaceBackOutputs(
paddle::SmallVector<std::shared_ptr<phi::MetaTensor>> outputs) {
paddle::SmallVector<MetaTensor, phi::kOutputSmallVectorSize> outputs) {
int index = outputs_.size();
output_range_.emplace_back(
std::pair<int, int>(index, index + outputs.size()));
......@@ -64,24 +62,25 @@ const std::pair<int, int>& InferMetaContext::OutputRangeAt(size_t idx) const {
const MetaConfig& InferMetaContext::GetMetaConfig() const { return config_; }
const MetaTensor& InferMetaContext::InputAt(size_t idx) const {
return *inputs_.at(idx);
return inputs_.at(idx);
}
paddle::optional<const phi::MetaTensor&> InferMetaContext::OptionalInputAt(
paddle::optional<const MetaTensor&> InferMetaContext::OptionalInputAt(
size_t idx) const {
const auto& input = inputs_.at(idx);
return input ? paddle::optional<const phi::MetaTensor&>{static_cast<
const phi::MetaTensor&>(*input)}
: paddle::optional<const phi::MetaTensor&>{paddle::none};
return input.initialized()
? paddle::optional<const MetaTensor&>{input}
: paddle::optional<const MetaTensor&>{paddle::none};
}
std::vector<MetaTensor*> InferMetaContext::InputsBetween(size_t start,
size_t end) const {
std::vector<MetaTensor*> result;
std::vector<const MetaTensor*> InferMetaContext::InputsBetween(
size_t start, size_t end) const {
std::vector<const MetaTensor*> result;
result.reserve(end - start);
for (size_t i = start; i < end; ++i) {
result.push_back(inputs_.at(i).get());
auto& in = inputs_.at(i);
result.emplace_back(in.initialized() ? &in : nullptr);
}
return result;
......@@ -91,12 +90,13 @@ paddle::optional<const std::vector<const MetaTensor*>>
InferMetaContext::OptionalInputsBetween(size_t start, size_t end) const {
const auto& first = inputs_.at(start);
if (first) {
if (first.initialized()) {
std::vector<const MetaTensor*> result;
result.reserve(end - start);
for (size_t i = start; i < end; ++i) {
result.push_back(inputs_.at(i).get());
auto& in = inputs_.at(i);
result.emplace_back(in.initialized() ? &in : nullptr);
}
return paddle::optional<const std::vector<const MetaTensor*>>(result);
......@@ -105,7 +105,8 @@ InferMetaContext::OptionalInputsBetween(size_t start, size_t end) const {
}
MetaTensor* InferMetaContext::MutableOutputAt(size_t idx) {
return outputs_.at(idx).get();
auto& out = outputs_.at(idx);
return out.initialized() ? &out : nullptr;
}
std::vector<MetaTensor*> InferMetaContext::MutableOutputBetween(size_t start,
......@@ -113,7 +114,8 @@ std::vector<MetaTensor*> InferMetaContext::MutableOutputBetween(size_t start,
std::vector<MetaTensor*> result;
result.reserve(end - start);
for (size_t i = start; i < end; ++i) {
result.emplace_back(outputs_.at(i).get());
auto& out = outputs_.at(i);
result.emplace_back(out.initialized() ? &out : nullptr);
}
return result;
}
......
......@@ -37,28 +37,28 @@ class InferMetaContext {
explicit InferMetaContext(MetaConfig config) : config_(config) {}
void SetMetaConfig(MetaConfig config);
void EmplaceBackInput(std::shared_ptr<phi::MetaTensor> input);
void EmplaceBackOutput(std::shared_ptr<phi::MetaTensor> output);
const MetaConfig& GetMetaConfig() const;
void EmplaceBackInput(MetaTensor input);
void EmplaceBackOutput(MetaTensor output);
void EmplaceBackAttr(paddle::any attr);
void EmplaceBackInputs(
paddle::SmallVector<std::shared_ptr<phi::MetaTensor>> inputs);
paddle::SmallVector<MetaTensor, phi::kInputSmallVectorSize> inputs);
void EmplaceBackOutputs(
paddle::SmallVector<std::shared_ptr<phi::MetaTensor>> outputs);
paddle::SmallVector<MetaTensor, phi::kOutputSmallVectorSize> outputs);
const std::pair<int, int>& InputRangeAt(size_t idx) const;
const std::pair<int, int>& OutputRangeAt(size_t idx) const;
virtual const MetaTensor& InputAt(size_t idx) const;
virtual paddle::optional<const MetaTensor&> OptionalInputAt(size_t idx) const;
const MetaConfig& GetMetaConfig() const;
const MetaTensor& InputAt(size_t idx) const;
paddle::optional<const phi::MetaTensor&> OptionalInputAt(size_t idx) const;
std::vector<MetaTensor*> InputsBetween(size_t start, size_t end) const;
paddle::optional<const std::vector<const phi::MetaTensor*>>
virtual std::vector<const MetaTensor*> InputsBetween(size_t start,
size_t end) const;
virtual paddle::optional<const std::vector<const MetaTensor*>>
OptionalInputsBetween(size_t start, size_t end) const;
MetaTensor* MutableOutputAt(size_t idx);
std::vector<MetaTensor*> MutableOutputBetween(size_t start, size_t end);
virtual MetaTensor* MutableOutputAt(size_t idx);
virtual std::vector<MetaTensor*> MutableOutputBetween(size_t start,
size_t end);
template <typename AttrType>
AttrType AttrAt(size_t idx) {
......@@ -73,19 +73,24 @@ class InferMetaContext {
}
}
private:
const std::pair<int, int>& InputRangeAt(size_t idx) const;
const std::pair<int, int>& OutputRangeAt(size_t idx) const;
virtual ~InferMetaContext() = default;
protected:
MetaConfig config_;
// NOTE(chenweihang): Because the MetaTensor is a base class, and MetaTensor
// objects are all created in each round, so we have to use smart pointer
// here, maybe we can implemented a new InferMetaContext and a series utils
// specifically for fluid to avoid using shared_ptr
paddle::SmallVector<std::shared_ptr<phi::MetaTensor>> inputs_;
paddle::SmallVector<std::shared_ptr<phi::MetaTensor>> outputs_;
paddle::SmallVector<paddle::any> attrs_;
paddle::SmallVector<paddle::any, kAttrSmallVectorSize> attrs_;
paddle::SmallVector<std::pair<int, int>> input_range_;
paddle::SmallVector<std::pair<int, int>> output_range_;
paddle::SmallVector<std::pair<int, int>, phi::kInputSmallVectorSize>
input_range_;
paddle::SmallVector<std::pair<int, int>, phi::kOutputSmallVectorSize>
output_range_;
private:
paddle::SmallVector<MetaTensor, phi::kInputSmallVectorSize> inputs_;
paddle::SmallVector<MetaTensor, phi::kOutputSmallVectorSize> outputs_;
};
#define PD_INFER_META(...) \
......@@ -159,7 +164,7 @@ struct InferMetaFnImpl<Return (*)(Args...), infer_meta_fn> {
};
template <typename... Tail>
struct InferMetaFnCallHelper<const std::vector<MetaTensor*>&, Tail...> {
struct InferMetaFnCallHelper<const std::vector<const MetaTensor*>&, Tail...> {
template <int in_idx, int attr_idx, int out_idx, typename... PreviousArgs>
static void Call(InferMetaContext* ctx, PreviousArgs&... pargs) {
static_assert(attr_idx == 0,
......@@ -167,7 +172,7 @@ struct InferMetaFnImpl<Return (*)(Args...), infer_meta_fn> {
static_assert(out_idx == 0,
"InferMeta's Input should appear before Outputs.");
const std::pair<int, int> range = ctx->InputRangeAt(in_idx);
std::vector<MetaTensor*> arg =
std::vector<const MetaTensor*> arg =
ctx->InputsBetween(range.first, range.second);
InferMetaFnCallHelper<
Tail...>::template Call<in_idx + 1, attr_idx, out_idx>(ctx,
......
......@@ -79,7 +79,7 @@ void KernelContext::EmplaceBackAttr(paddle::any attr) {
void KernelContext::AssignInputRange(std::pair<int, int>&& range, size_t idx) {
if (idx < input_range_.size()) {
input_range_[idx] = range;
input_range_[idx] = std::move(range);
} else if (idx == input_range_.size()) {
input_range_.emplace_back(range);
} else {
......@@ -93,7 +93,7 @@ void KernelContext::AssignInputRange(std::pair<int, int>&& range, size_t idx) {
void KernelContext::AssignOutputRange(std::pair<int, int>&& range, size_t idx) {
if (idx < output_range_.size()) {
output_range_[idx] = range;
output_range_[idx] = std::move(range);
} else if (idx == output_range_.size()) {
output_range_.emplace_back(range);
} else {
......
......@@ -19,6 +19,8 @@
namespace phi {
const static Kernel empty_kernel; // NOLINT
uint32_t KernelKey::Hash::operator()(const KernelKey& key) const {
uint32_t hash_value = 0;
// |----31-20------|---19-12---|---11-8----|---7-0---|
......@@ -37,15 +39,15 @@ KernelFactory& KernelFactory::Instance() {
return g_op_kernel_factory;
}
Kernel KernelFactory::SelectKernel(const std::string& kernel_name,
const Kernel& KernelFactory::SelectKernel(const std::string& kernel_name,
const KernelKey& kernel_key) const {
auto iter = kernels_.find(kernel_name);
if (iter == kernels_.end()) {
return Kernel();
return empty_kernel;
}
auto kernel_iter = iter->second.find(kernel_key);
if (kernel_iter == iter->second.end()) {
return Kernel();
return empty_kernel;
}
return kernel_iter->second;
}
......@@ -59,7 +61,7 @@ KernelKeyMap KernelFactory::SelectKernelMap(
return iter->second;
}
bool KernelFactory::IsSelectKernelValid(const std::string& kernel_name,
bool KernelFactory::HasKernel(const std::string& kernel_name,
const KernelKey& kernel_key) const {
auto iter = kernels_.find(kernel_name);
PADDLE_ENFORCE_NE(
......@@ -128,6 +130,16 @@ const Kernel& KernelFactory::SelectKernelOrThrowError(
KernelKey(backend, layout, dtype));
}
const KernelArgsDef& KernelFactory::GetFirstKernelArgsDef(
const std::string& kernel_name) const {
auto iter = kernels_.find(kernel_name);
PADDLE_ENFORCE_NE(
iter,
kernels_.end(),
phi::errors::NotFound("The kernel `%s` is not registered.", kernel_name));
return iter->second.cbegin()->second.args_def();
}
// print kernel info with json format:
// {
// "(CPU, Undefined(AnyLayout), complex64)": {
......
......@@ -151,30 +151,38 @@ class KernelArgsDef {
attribute_defs_.emplace_back(AttributeArgDef(type_index));
}
const paddle::SmallVector<TensorArgDef>& input_defs() const {
const paddle::SmallVector<TensorArgDef, kInputSmallVectorSize>& input_defs()
const {
return input_defs_;
}
const paddle::SmallVector<TensorArgDef>& output_defs() const {
const paddle::SmallVector<TensorArgDef, kOutputSmallVectorSize>& output_defs()
const {
return output_defs_;
}
const paddle::SmallVector<AttributeArgDef>& attribute_defs() const {
const paddle::SmallVector<AttributeArgDef, kAttrSmallVectorSize>&
attribute_defs() const {
return attribute_defs_;
}
paddle::SmallVector<TensorArgDef>& input_defs() { return input_defs_; }
paddle::SmallVector<TensorArgDef, kInputSmallVectorSize>& input_defs() {
return input_defs_;
}
paddle::SmallVector<TensorArgDef>& output_defs() { return output_defs_; }
paddle::SmallVector<TensorArgDef, kOutputSmallVectorSize>& output_defs() {
return output_defs_;
}
paddle::SmallVector<AttributeArgDef>& attribute_defs() {
paddle::SmallVector<AttributeArgDef, kAttrSmallVectorSize>& attribute_defs() {
return attribute_defs_;
}
private:
paddle::SmallVector<TensorArgDef> input_defs_{{}};
paddle::SmallVector<TensorArgDef> output_defs_{{}};
paddle::SmallVector<AttributeArgDef> attribute_defs_{{}};
paddle::SmallVector<TensorArgDef, kInputSmallVectorSize> input_defs_{{}};
paddle::SmallVector<TensorArgDef, kOutputSmallVectorSize> output_defs_{{}};
paddle::SmallVector<AttributeArgDef, kAttrSmallVectorSize> attribute_defs_{
{}};
};
class Kernel {
......@@ -209,7 +217,7 @@ class Kernel {
TensorArgDef& OutputAt(size_t idx) { return args_def_.output_defs().at(idx); }
bool IsValid() { return fn_ != nullptr; }
bool IsValid() const { return fn_ != nullptr; }
private:
KernelFn fn_{nullptr};
......@@ -246,14 +254,17 @@ class KernelFactory {
DataLayout layout,
DataType dtype) const;
bool IsSelectKernelValid(const std::string& kernel_name,
bool HasKernel(const std::string& kernel_name,
const KernelKey& kernel_key) const;
Kernel SelectKernel(const std::string& kernel_name,
const Kernel& SelectKernel(const std::string& kernel_name,
const KernelKey& kernel_key) const;
KernelKeyMap SelectKernelMap(const std::string& kernel_name) const;
const KernelArgsDef& GetFirstKernelArgsDef(
const std::string& kernel_name) const;
private:
KernelFactory() = default;
......
......@@ -148,4 +148,6 @@ void MetaTensor::share_dims(const MetaTensor& meta_tensor) {
}
}
bool MetaTensor::initialized() const { return tensor_ != nullptr; }
} // namespace phi
......@@ -45,10 +45,10 @@ class MetaTensor {
: tensor_(const_cast<TensorBase*>(&tensor)) {}
MetaTensor(TensorBase& tensor) : tensor_(&tensor) {} // NOLINT
MetaTensor(const MetaTensor&) = default;
MetaTensor(MetaTensor&&) = default;
MetaTensor& operator=(const MetaTensor&) = delete;
MetaTensor& operator=(MetaTensor&&) = delete;
MetaTensor& operator=(MetaTensor&&) = default;
MetaTensor(const MetaTensor&) = default;
MetaTensor& operator=(const MetaTensor&) = default;
virtual ~MetaTensor() = default;
......@@ -64,6 +64,8 @@ class MetaTensor {
virtual void share_meta(const MetaTensor& meta_tensor);
virtual void share_dims(const MetaTensor& meta_tensor);
virtual bool initialized() const;
private:
// Because the lod in compiletime and runtime is different,
// so `LoD` cannot in public methods
......
......@@ -22,7 +22,7 @@ class Kernel;
class KernelKey;
class KernelArgsDef;
class KernelContext;
class KernelSignature;
struct KernelSignature;
class ArgumentMappingContext;
class InferMetaContext;
......@@ -35,4 +35,9 @@ using ArgumentMappingFn =
std::function<KernelSignature(const ArgumentMappingContext&)>;
using InferMetaFn = void (*)(InferMetaContext* ctx);
// Global SmallVector size setting
constexpr size_t kInputSmallVectorSize = 10U;
constexpr size_t kAttrSmallVectorSize = 10U;
constexpr size_t kOutputSmallVectorSize = 5U;
} // namespace phi
......@@ -315,8 +315,8 @@ void MaxPoolWithIndexGradInferMeta(const MetaTensor& x,
dx->share_meta(x);
}
void MeshgridGradInferMeta(const std::vector<MetaTensor*>& inputs,
const std::vector<MetaTensor*>& outputs_grad,
void MeshgridGradInferMeta(const std::vector<const MetaTensor*>& inputs,
const std::vector<const MetaTensor*>& outputs_grad,
std::vector<MetaTensor*> inputs_grad) {
PADDLE_ENFORCE_GT(outputs_grad.size(),
1,
......@@ -329,7 +329,7 @@ void MeshgridGradInferMeta(const std::vector<MetaTensor*>& inputs,
}
}
void MultiDotGradInferMeta(const std::vector<MetaTensor*>& x,
void MultiDotGradInferMeta(const std::vector<const MetaTensor*>& x,
const MetaTensor& out_grad,
std::vector<MetaTensor*> x_grad) {
PADDLE_ENFORCE_EQ(
......
......@@ -151,11 +151,11 @@ void MaxPoolWithIndexGradInferMeta(const MetaTensor& x,
bool adaptive,
MetaTensor* dx);
void MeshgridGradInferMeta(const std::vector<MetaTensor*>& inputs,
const std::vector<MetaTensor*>& outputs_grad,
void MeshgridGradInferMeta(const std::vector<const MetaTensor*>& inputs,
const std::vector<const MetaTensor*>& outputs_grad,
std::vector<MetaTensor*> inputs_grad);
void MultiDotGradInferMeta(const std::vector<MetaTensor*>& x,
void MultiDotGradInferMeta(const std::vector<const MetaTensor*>& x,
const MetaTensor& out_grad,
std::vector<MetaTensor*> x_grad);
......
......@@ -21,7 +21,8 @@ limitations under the License. */
#include "paddle/phi/kernels/funcs/concat_funcs.h"
namespace phi {
std::vector<DDim> GetMetaTensorsDim(const std::vector<MetaTensor*>& tensors) {
std::vector<DDim> GetMetaTensorsDim(
const std::vector<const MetaTensor*>& tensors) {
std::vector<DDim> dims;
dims.reserve(tensors.size());
for (const MetaTensor* tensor : tensors) {
......@@ -279,7 +280,7 @@ void AdamwInferMeta(const MetaTensor& param,
master_param_outs);
}
void AddNInferMeta(const std::vector<MetaTensor*>& x,
void AddNInferMeta(const std::vector<const MetaTensor*>& x,
MetaTensor* out,
MetaConfig config) {
auto N = x.size();
......@@ -642,7 +643,7 @@ void BilinearTensorProductInferMeta(const MetaTensor& x,
out->set_dtype(x.dtype());
}
void BroadcastTensorsInferMeta(const std::vector<MetaTensor*>& x,
void BroadcastTensorsInferMeta(const std::vector<const MetaTensor*>& x,
std::vector<MetaTensor*> out) {
int target_rank = 0;
const auto& input_dims = GetMetaTensorsDim(x);
......@@ -696,7 +697,7 @@ void BroadcastTensorsInferMeta(const std::vector<MetaTensor*>& x,
}
}
void ConcatInferMeta(const std::vector<MetaTensor*>& x,
void ConcatInferMeta(const std::vector<const MetaTensor*>& x,
const Scalar& axis_scalar,
MetaTensor* out,
MetaConfig config) {
......@@ -1488,7 +1489,7 @@ void InterpolateInferMeta(
}
}
void MeshgridInferMeta(const std::vector<MetaTensor*>& inputs,
void MeshgridInferMeta(const std::vector<const MetaTensor*>& inputs,
std::vector<MetaTensor*> outputs) {
const size_t inputs_num = inputs.size();
......@@ -1551,7 +1552,8 @@ void MomentumInferMeta(const MetaTensor& param,
}
}
void MultiDotInferMeta(const std::vector<MetaTensor*>& x, MetaTensor* out) {
void MultiDotInferMeta(const std::vector<const MetaTensor*>& x,
MetaTensor* out) {
auto inputs_dims = GetMetaTensorsDim(x);
const size_t inputs_num = inputs_dims.size();
......@@ -1624,7 +1626,7 @@ void MultiDotInferMeta(const std::vector<MetaTensor*>& x, MetaTensor* out) {
out->share_lod(*x.at(0));
}
void MultiplexInferMeta(const std::vector<MetaTensor*>& ins,
void MultiplexInferMeta(const std::vector<const MetaTensor*>& ins,
const MetaTensor& ids,
MetaTensor* out) {
PADDLE_ENFORCE_NE(
......@@ -1803,8 +1805,8 @@ void RmspropInferMeta(const MetaTensor& param,
}
void RnnInferMeta(const MetaTensor& x,
const std::vector<MetaTensor*>& pre_state,
const std::vector<MetaTensor*>& weight_list,
const std::vector<const MetaTensor*>& pre_state,
const std::vector<const MetaTensor*>& weight_list,
paddle::optional<const MetaTensor&> sequence_length,
float dropout_prob,
bool is_bidirec,
......@@ -1910,7 +1912,7 @@ void SgdInferMeta(const MetaTensor& param,
param_out->set_dtype(param.dtype());
}
void StackInferMeta(const std::vector<MetaTensor*>& x,
void StackInferMeta(const std::vector<const MetaTensor*>& x,
int axis,
MetaTensor* out) {
PADDLE_ENFORCE_GT(x.size(),
......@@ -1956,7 +1958,7 @@ void StackInferMeta(const std::vector<MetaTensor*>& x,
out->share_lod(*x.at(0));
}
void UnchangedMultiInferMeta(const std::vector<MetaTensor*>& x,
void UnchangedMultiInferMeta(const std::vector<const MetaTensor*>& x,
std::vector<MetaTensor*> out) {
for (size_t i = 0; i < x.size(); ++i) {
out[i]->share_meta(*x[i]);
......
......@@ -35,7 +35,8 @@ namespace phi {
//
// NOTE: The InferMeta Functions in this file are arranged in alphabetic order.
std::vector<DDim> GetMetaTensorsDim(const std::vector<MetaTensor*>& tensors);
std::vector<DDim> GetMetaTensorsDim(
const std::vector<const MetaTensor*>& tensors);
void AdadeltaInferMeta(const MetaTensor& param,
const MetaTensor& grad,
......@@ -117,7 +118,7 @@ void AdamwInferMeta(const MetaTensor& param,
MetaTensor* beta2_pow_out,
MetaTensor* master_param_outs);
void AddNInferMeta(const std::vector<MetaTensor*>& x,
void AddNInferMeta(const std::vector<const MetaTensor*>& x,
MetaTensor* out,
MetaConfig config = MetaConfig());
......@@ -173,10 +174,10 @@ void BilinearTensorProductInferMeta(const MetaTensor& x,
MetaTensor* out,
MetaConfig config = MetaConfig());
void BroadcastTensorsInferMeta(const std::vector<MetaTensor*>& x,
void BroadcastTensorsInferMeta(const std::vector<const MetaTensor*>& x,
std::vector<MetaTensor*> out);
void ConcatInferMeta(const std::vector<MetaTensor*>& x,
void ConcatInferMeta(const std::vector<const MetaTensor*>& x,
const Scalar& axis_scalar,
MetaTensor* out,
MetaConfig config = MetaConfig());
......@@ -227,7 +228,7 @@ void InterpolateInferMeta(
MetaTensor* output,
MetaConfig config = MetaConfig());
void MeshgridInferMeta(const std::vector<MetaTensor*>& inputs,
void MeshgridInferMeta(const std::vector<const MetaTensor*>& inputs,
std::vector<MetaTensor*> outputs);
void MomentumInferMeta(const MetaTensor& param,
......@@ -245,9 +246,10 @@ void MomentumInferMeta(const MetaTensor& param,
MetaTensor* velocity_out,
MetaTensor* master_param_out);
void MultiDotInferMeta(const std::vector<MetaTensor*>& x, MetaTensor* out);
void MultiDotInferMeta(const std::vector<const MetaTensor*>& x,
MetaTensor* out);
void MultiplexInferMeta(const std::vector<MetaTensor*>& ins,
void MultiplexInferMeta(const std::vector<const MetaTensor*>& ins,
const MetaTensor& ids,
MetaTensor* out);
......@@ -276,8 +278,8 @@ void RmspropInferMeta(const MetaTensor& param,
MetaTensor* mean_grad_out);
void RnnInferMeta(const MetaTensor& x,
const std::vector<MetaTensor*>& pre_state,
const std::vector<MetaTensor*>& weight_list,
const std::vector<const MetaTensor*>& pre_state,
const std::vector<const MetaTensor*>& weight_list,
paddle::optional<const MetaTensor&> sequence_length,
float dropout_prob,
bool is_bidirec,
......@@ -300,11 +302,11 @@ void SgdInferMeta(const MetaTensor& param,
MetaTensor* param_out,
MetaTensor* master_param_out);
void StackInferMeta(const std::vector<MetaTensor*>& x,
void StackInferMeta(const std::vector<const MetaTensor*>& x,
int axis,
MetaTensor* out);
void UnchangedMultiInferMeta(const std::vector<MetaTensor*>& x,
void UnchangedMultiInferMeta(const std::vector<const MetaTensor*>& x,
std::vector<MetaTensor*> out);
void WarpctcInferMeta(const MetaTensor& logits,
......
......@@ -32,7 +32,7 @@ DenseTensor Concat(const Context& dev_ctx,
const Scalar& axis) {
std::vector<MetaTensor> meta_x;
meta_x.reserve(x.size());
std::vector<MetaTensor*> meta_x_ptr;
std::vector<const MetaTensor*> meta_x_ptr;
for (const auto* t : x) {
meta_x.emplace_back(*t);
meta_x_ptr.push_back(&meta_x.back());
......
......@@ -21,8 +21,7 @@ KernelSignature AbsOpArgumentMapping(const ArgumentMappingContext& ctx) {
}
KernelSignature AbsGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature(
"abs_grad", {"X", GradVarName("Out")}, {}, {GradVarName("X")});
return KernelSignature("abs_grad", {"X", "Out@GRAD"}, {}, {"X@GRAD"});
}
KernelSignature AbsDoubleGradOpArgumentMapping(
......
......@@ -19,26 +19,22 @@ namespace phi {
#define DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(func_name, op_name, attrs) \
KernelSignature func_name##GradOpArgumentMapping( \
const ArgumentMappingContext& ctx) { \
return KernelSignature(op_name "_grad", \
{"X", GradVarName("Out")}, \
{attrs}, \
{GradVarName("X")}); \
return KernelSignature( \
op_name "_grad", {"X", "Out@GRAD"}, {attrs}, {"X@GRAD"}); \
}
#define DEFINE_ACT_GRAD_DEPOUT_OP_ARGMAP(func_name, op_name, attrs) \
KernelSignature func_name##GradOpArgumentMapping( \
const ArgumentMappingContext& ctx) { \
return KernelSignature(op_name "_grad", \
{"Out", GradVarName("Out")}, \
{attrs}, \
{GradVarName("X")}); \
return KernelSignature( \
op_name "_grad", {"Out", "Out@GRAD"}, {attrs}, {"X@GRAD"}); \
}
#define DEFINE_ACT_GRAD_NODEP_OP_ARGMAP(func_name, op_name, attrs) \
KernelSignature func_name##GradOpArgumentMapping( \
const ArgumentMappingContext& ctx) { \
return KernelSignature( \
op_name "_grad", {GradVarName("Out")}, {attrs}, {GradVarName("X")}); \
op_name "_grad", {"Out@GRAD"}, {attrs}, {"X@GRAD"}); \
}
#define comma ,
......@@ -165,15 +161,12 @@ KernelSignature EluOpArgumentMapping(const ArgumentMappingContext& ctx) {
}
KernelSignature LogitGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature(
"logit_grad", {"X", GradVarName("Out")}, {"eps"}, {GradVarName("X")});
return KernelSignature("logit_grad", {"X", "Out@GRAD"}, {"eps"}, {"X@GRAD"});
}
KernelSignature EluGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("elu_grad",
{"X", "Out", GradVarName("Out")},
{"alpha"},
{GradVarName("X")});
return KernelSignature(
"elu_grad", {"X", "Out", "Out@GRAD"}, {"alpha"}, {"X@GRAD"});
}
KernelSignature EluDoubleGradOpArgumentMapping(
......@@ -198,13 +191,11 @@ KernelSignature PowOpArgumentMapping(const ArgumentMappingContext& ctx) {
KernelSignature PowGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
if (ctx.HasInput("FactorTensor")) {
return KernelSignature("pow_grad",
{"X", GradVarName("Out")},
{"FactorTensor"},
{GradVarName("X")});
return KernelSignature(
"pow_grad", {"X", "Out@GRAD"}, {"FactorTensor"}, {"X@GRAD"});
} else {
return KernelSignature(
"pow_grad", {"X", GradVarName("Out")}, {"factor"}, {GradVarName("X")});
"pow_grad", {"X", "Out@GRAD"}, {"factor"}, {"X@GRAD"});
}
}
......
......@@ -19,7 +19,7 @@
namespace phi {
KernelSignature AdamOpArgumentMapping(const ArgumentMappingContext& ctx) {
paddle::SmallVector<std::string> in_names = {"Param",
paddle::SmallVector<const char*> in_names = {"Param",
"Grad",
"LearningRate",
"Moment1",
......@@ -28,13 +28,13 @@ KernelSignature AdamOpArgumentMapping(const ArgumentMappingContext& ctx) {
"Beta2Pow",
"MasterParam",
"SkipUpdate"};
paddle::SmallVector<std::string> out_names = {"ParamOut",
paddle::SmallVector<const char*> out_names = {"ParamOut",
"Moment1Out",
"Moment2Out",
"Beta1PowOut",
"Beta2PowOut",
"MasterParamOut"};
paddle::SmallVector<std::string> attr_names;
paddle::SmallVector<const char*> attr_names;
attr_names.emplace_back(ctx.HasInput("Beta1Tensor") ? "Beta1Tensor"
: "beta1");
......
......@@ -19,7 +19,7 @@
namespace phi {
KernelSignature AdamwOpArgumentMapping(const ArgumentMappingContext& ctx) {
paddle::SmallVector<std::string> in_names = {"Param",
paddle::SmallVector<const char*> in_names = {"Param",
"Grad",
"LearningRate",
"Moment1",
......@@ -28,13 +28,13 @@ KernelSignature AdamwOpArgumentMapping(const ArgumentMappingContext& ctx) {
"Beta2Pow",
"MasterParam",
"SkipUpdate"};
paddle::SmallVector<std::string> out_names = {"ParamOut",
paddle::SmallVector<const char*> out_names = {"ParamOut",
"Moment1Out",
"Moment2Out",
"Beta1PowOut",
"Beta2PowOut",
"MasterParamOut"};
paddle::SmallVector<std::string> attr_names;
paddle::SmallVector<const char*> attr_names;
attr_names.emplace_back(ctx.HasInput("Beta1Tensor") ? "Beta1Tensor"
: "beta1");
......
......@@ -17,11 +17,10 @@
namespace phi {
KernelSignature AddmmGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature(
"addmm_grad",
{"Input", "X", "Y", GradVarName("Out")},
return KernelSignature("addmm_grad",
{"Input", "X", "Y", "Out@GRAD"},
{"Alpha", "Beta"},
{GradVarName("Input"), GradVarName("X"), GradVarName("Y")});
{"Input@GRAD", "X@GRAD", "Y@GRAD"});
}
} // namespace phi
......
......@@ -19,9 +19,9 @@ namespace phi {
KernelSignature ArgsortGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("argsort_grad",
{"Indices", "X", GradVarName("Out")},
{"Indices", "X", "Out@GRAD"},
{"axis", "descending"},
{GradVarName("X")});
{"X@GRAD"});
}
} // namespace phi
......
......@@ -17,10 +17,8 @@
namespace phi {
KernelSignature Atan2GradOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("atan2_grad",
{"X1", "X2", GradVarName("Out")},
{},
{GradVarName("X1"), GradVarName("X2")});
return KernelSignature(
"atan2_grad", {"X1", "X2", "Out@GRAD"}, {}, {"X1@GRAD", "X2@GRAD"});
}
} // namespace phi
......
......@@ -57,8 +57,7 @@ KernelSignature BatchNormOpArgumentMapping(const ArgumentMappingContext& ctx) {
KernelSignature BatchNormGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature(
"batch_norm_grad",
return KernelSignature("batch_norm_grad",
{
"X",
"Scale",
......@@ -68,7 +67,7 @@ KernelSignature BatchNormGradOpArgumentMapping(
"SavedMean",
"SavedVariance",
"ReserveSpace",
GradVarName("Y"),
"Y@GRAD",
},
{"momentum",
"epsilon",
......@@ -77,7 +76,7 @@ KernelSignature BatchNormGradOpArgumentMapping(
"use_global_stats",
"trainable_statistics",
"fuse_with_relu"},
{GradVarName("X"), GradVarName("Scale"), GradVarName("Bias")});
{"X@GRAD", "Scale@GRAD", "Bias@GRAD"});
}
KernelSignature BatchNormGradGradOpArgumentMapping(
......
......@@ -18,10 +18,8 @@ namespace phi {
KernelSignature BCELossGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("bce_loss_grad",
{"X", "Label", GradVarName("Out")},
{},
{GradVarName("X")});
return KernelSignature(
"bce_loss_grad", {"X", "Label", "Out@GRAD"}, {}, {"X@GRAD"});
}
} // namespace phi
......
......@@ -25,12 +25,9 @@ KernelSignature BilinearTensorProductOpArgumentMapping(
KernelSignature BilinearTensorProductGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("bilinear_tensor_product_grad",
{"X", "Y", "Weight", GradVarName("Out")},
{"X", "Y", "Weight", "Out@GRAD"},
{},
{GradVarName("X"),
GradVarName("Y"),
GradVarName("Weight"),
GradVarName("Bias")});
{"X@GRAD", "Y@GRAD", "Weight@GRAD", "Bias@GRAD"});
}
} // namespace phi
......
......@@ -19,7 +19,7 @@ namespace phi {
KernelSignature BroadcastTensorsGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature(
"broadcast_tensors_grad", {GradVarName("Out")}, {}, {GradVarName("X")});
"broadcast_tensors_grad", {"Out@GRAD"}, {}, {"X@GRAD"});
}
} // namespace phi
......
......@@ -18,10 +18,8 @@ namespace phi {
KernelSignature CholeskyGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("cholesky_grad",
{"Out", GradVarName("Out")},
{"upper"},
{GradVarName("X")});
return KernelSignature(
"cholesky_grad", {"Out", "Out@GRAD"}, {"upper"}, {"X@GRAD"});
}
} // namespace phi
......
......@@ -19,9 +19,9 @@ namespace phi {
KernelSignature CholeskySolveGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("cholesky_solve_grad",
{"X", "Y", "Out", GradVarName("Out")},
{"X", "Y", "Out", "Out@GRAD"},
{"upper"},
{GradVarName("X"), GradVarName("Y")});
{"X@GRAD", "Y@GRAD"});
}
} // namespace phi
......
......@@ -18,7 +18,7 @@
namespace phi {
KernelSignature ClipOpArgumentMapping(const ArgumentMappingContext& ctx) {
paddle::SmallVector<std::string> attr_names;
paddle::SmallVector<std::string, kAttrSmallVectorSize> attr_names;
attr_names.emplace_back(ctx.HasInput("Min") ? "Min" : "min");
attr_names.emplace_back(ctx.HasInput("Max") ? "Max" : "max");
if (ctx.IsDenseTensorInput("X")) {
......@@ -57,27 +57,19 @@ KernelSignature ClipOpArgumentMapping(const ArgumentMappingContext& ctx) {
KernelSignature ClipGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
if (ctx.HasInput("Min")) {
if (ctx.HasInput("Max")) {
return KernelSignature("clip_grad",
{"X", GradVarName("Out")},
{"Min", "Max"},
{GradVarName("X")});
return KernelSignature(
"clip_grad", {"X", "Out@GRAD"}, {"Min", "Max"}, {"X@GRAD"});
} else {
return KernelSignature("clip_grad",
{"X", GradVarName("Out")},
{"Min", "max"},
{GradVarName("X")});
return KernelSignature(
"clip_grad", {"X", "Out@GRAD"}, {"Min", "max"}, {"X@GRAD"});
}
} else {
if (ctx.HasInput("Max")) {
return KernelSignature("clip_grad",
{"X", GradVarName("Out")},
{"min", "Max"},
{GradVarName("X")});
return KernelSignature(
"clip_grad", {"X", "Out@GRAD"}, {"min", "Max"}, {"X@GRAD"});
} else {
return KernelSignature("clip_grad",
{"X", GradVarName("Out")},
{"min", "max"},
{GradVarName("X")});
return KernelSignature(
"clip_grad", {"X", "Out@GRAD"}, {"min", "max"}, {"X@GRAD"});
}
}
}
......
......@@ -17,13 +17,11 @@
namespace phi {
KernelSignature RealGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature(
"real_grad", {GradVarName("Out")}, {}, {GradVarName("X")});
return KernelSignature("real_grad", {"Out@GRAD"}, {}, {"X@GRAD"});
}
KernelSignature ImagGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature(
"imag_grad", {GradVarName("Out")}, {}, {GradVarName("X")});
return KernelSignature("imag_grad", {"Out@GRAD"}, {}, {"X@GRAD"});
}
} // namespace phi
......
......@@ -25,15 +25,11 @@ KernelSignature ConcatOpArgumentMapping(const ArgumentMappingContext& ctx) {
KernelSignature ConcatGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
if (ctx.HasInput("AxisTensor")) {
return KernelSignature("concat_grad",
{"X", {GradVarName("Out")}},
{"AxisTensor"},
{{GradVarName("X")}});
return KernelSignature(
"concat_grad", {"X", {"Out@GRAD"}}, {"AxisTensor"}, {{"X@GRAD"}});
}
return KernelSignature("concat_grad",
{"X", {GradVarName("Out")}},
{"axis"},
{{GradVarName("X")}});
return KernelSignature(
"concat_grad", {"X", {"Out@GRAD"}}, {"axis"}, {{"X@GRAD"}});
}
} // namespace phi
......
......@@ -46,7 +46,7 @@ KernelSignature Conv2dOpArgumentMapping(const ArgumentMappingContext& ctx) {
KernelSignature Conv2dGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("conv2d_grad",
{"Input", "Filter", GradVarName("Output")},
{"Input", "Filter", "Output@GRAD"},
{"strides",
"paddings",
"padding_algorithm",
......@@ -56,7 +56,7 @@ KernelSignature Conv2dGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
"use_addto",
"workspace_size_MB",
"exhaustive_search"},
{GradVarName("Input"), GradVarName("Filter")});
{"Input@GRAD", "Filter@GRAD"});
}
KernelSignature Conv2dDoubleGradOpArgumentMapping(
......
......@@ -33,7 +33,7 @@ KernelSignature Conv3dOpArgumentMapping(const ArgumentMappingContext& ctx) {
KernelSignature Conv3dGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("conv2d_grad",
{"Input", "Filter", GradVarName("Output")},
{"Input", "Filter", "Output@GRAD"},
{"strides",
"paddings",
"padding_algorithm",
......@@ -43,7 +43,7 @@ KernelSignature Conv3dGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
"use_addto",
"workspace_size_MB",
"exhaustive_search"},
{GradVarName("Input"), GradVarName("Filter")});
{"Input@GRAD", "Filter@GRAD"});
}
KernelSignature Conv3dDoubleGradOpArgumentMapping(
......
......@@ -34,7 +34,7 @@ KernelSignature Conv2dTransposeOpArgumentMapping(
KernelSignature Conv2dTransposeGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("conv2d_transpose_grad",
{"Input", "Filter", GradVarName("Output")},
{"Input", "Filter", "Output@GRAD"},
{"strides",
"paddings",
"output_padding",
......@@ -43,7 +43,7 @@ KernelSignature Conv2dTransposeGradOpArgumentMapping(
"groups",
"dilations",
"data_format"},
{GradVarName("Input"), GradVarName("Filter")});
{"Input@GRAD", "Filter@GRAD"});
}
KernelSignature Conv2dTransposeDoubleGradOpArgumentMapping(
......@@ -79,7 +79,7 @@ KernelSignature Conv3dTransposeOpArgumentMapping(
KernelSignature Conv3dTransposeGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("conv3d_transpose_grad",
{"Input", "Filter", GradVarName("Output")},
{"Input", "Filter", "Output@GRAD"},
{"strides",
"paddings",
"output_padding",
......@@ -88,7 +88,7 @@ KernelSignature Conv3dTransposeGradOpArgumentMapping(
"groups",
"dilations",
"data_format"},
{GradVarName("Input"), GradVarName("Filter")});
{"Input@GRAD", "Filter@GRAD"});
}
KernelSignature DepthwiseConv2dTransposeOpArgumentMapping(
......@@ -109,7 +109,7 @@ KernelSignature DepthwiseConv2dTransposeOpArgumentMapping(
KernelSignature DepthwiseConv2dTransposeGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("depthwise_conv2d_transpose_grad",
{"Input", "Filter", GradVarName("Output")},
{"Input", "Filter", "Output@GRAD"},
{"strides",
"paddings",
"output_padding",
......@@ -118,7 +118,7 @@ KernelSignature DepthwiseConv2dTransposeGradOpArgumentMapping(
"groups",
"dilations",
"data_format"},
{GradVarName("Input"), GradVarName("Filter")});
{"Input@GRAD", "Filter@GRAD"});
}
} // namespace phi
......
......@@ -21,10 +21,8 @@ KernelSignature CrossOpArgumentMapping(const ArgumentMappingContext& ctx) {
}
KernelSignature CrossGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("cross_grad",
{"X", "Y", GradVarName("Out")},
{"dim"},
{GradVarName("X"), GradVarName("Y")});
return KernelSignature(
"cross_grad", {"X", "Y", "Out@GRAD"}, {"dim"}, {"X@GRAD", "Y@GRAD"});
}
} // namespace phi
......
......@@ -18,10 +18,8 @@ namespace phi {
KernelSignature CumprodGradGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("cumprod_grad",
{"X", "Out", GradVarName("Out")},
{"dim"},
{GradVarName("X")});
return KernelSignature(
"cumprod_grad", {"X", "Out", "Out@GRAD"}, {"dim"}, {"X@GRAD"});
}
} // namespace phi
......
......@@ -33,17 +33,14 @@ KernelSignature DeformableConvGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature(
"deformable_conv_grad",
{"Input", "Offset", "Filter", "Mask", GradVarName("Output")},
{"Input", "Offset", "Filter", "Mask", "Output@GRAD"},
{"strides",
"paddings",
"dilations",
"deformable_groups",
"groups",
"im2col_step"},
{GradVarName("Input"),
GradVarName("Offset"),
GradVarName("Filter"),
GradVarName("Mask")});
{"Input@GRAD", "Offset@GRAD", "Filter@GRAD", "Mask@GRAD"});
}
} // namespace phi
......
......@@ -36,7 +36,7 @@ KernelSignature DepthwiseConv2dOpArgumentMapping(
KernelSignature DepthwiseConv2dGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("depthwise_conv2d_grad",
{"Input", "Filter", GradVarName("Output")},
{"Input", "Filter", "Output@GRAD"},
{"strides",
"paddings",
"padding_algorithm",
......@@ -47,7 +47,7 @@ KernelSignature DepthwiseConv2dGradOpArgumentMapping(
"workspace_size_MB",
"exhaustive_search",
"fuse_relu_before_depthwise_conv"},
{GradVarName("Input"), GradVarName("Filter")});
{"Input@GRAD", "Filter@GRAD"});
}
KernelSignature DepthwiseConv2dDoubleGradOpArgumentMapping(
......
......@@ -18,10 +18,8 @@ namespace phi {
KernelSignature DeterminantGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("determinant_grad",
{"Input", "Out", GradVarName("Out")},
{},
{GradVarName("Input")});
return KernelSignature(
"determinant_grad", {"Input", "Out", "Out@GRAD"}, {}, {"Input@GRAD"});
}
} // namespace phi
......
......@@ -22,7 +22,7 @@ KernelSignature DiagOpArgumentMapping(const ArgumentMappingContext& ctx) {
KernelSignature DiagGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature(
"diag_grad", {"X", GradVarName("Out")}, {"offset"}, {GradVarName("X")});
"diag_grad", {"X", "Out@GRAD"}, {"offset"}, {"X@GRAD"});
}
} // namespace phi
......
......@@ -19,9 +19,9 @@ namespace phi {
KernelSignature DiagonalGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("diagonal_grad",
{"Input", GradVarName("Out")},
{"Input", "Out@GRAD"},
{"offset", "axis1", "axis2"},
{GradVarName("Input")});
{"Input@GRAD"});
}
} // namespace phi
......
......@@ -18,8 +18,7 @@ namespace phi {
KernelSignature DigammaGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature(
"digamma_grad", {"X", GradVarName("Out")}, {}, {GradVarName("X")});
return KernelSignature("digamma_grad", {"X", "Out@GRAD"}, {}, {"X@GRAD"});
}
} // namespace phi
......
......@@ -17,10 +17,8 @@ limitations under the License. */
namespace phi {
KernelSignature DistGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("dist_grad",
{"X", "Y", "Out", GradVarName("Out")},
{"p"},
{GradVarName("X"), GradVarName("Y")});
return KernelSignature(
"dist_grad", {"X", "Y", "Out", "Out@GRAD"}, {"p"}, {"X@GRAD", "Y@GRAD"});
}
} // namespace phi
......
......@@ -17,10 +17,8 @@ limitations under the License. */
namespace phi {
KernelSignature DotGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("dot_grad",
{"X", "Y", GradVarName("Out")},
{},
{GradVarName("X"), GradVarName("Y")});
return KernelSignature(
"dot_grad", {"X", "Y", "Out@GRAD"}, {}, {"X@GRAD", "Y@GRAD"});
}
} // namespace phi
......
......@@ -27,9 +27,9 @@ KernelSignature DropoutOpArgumentMapping(const ArgumentMappingContext& ctx) {
KernelSignature DropoutGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("dropout_grad",
{"Mask", GradVarName("Out")},
{"Mask", "Out@GRAD"},
{"dropout_prob", "is_test", "dropout_implementation"},
{GradVarName("X")});
{"X@GRAD"});
}
} // namespace phi
......
......@@ -17,13 +17,11 @@
namespace phi {
KernelSignature EighGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("eigh_grad",
{"Eigenvalues",
"Eigenvectors",
GradVarName("Eigenvalues"),
GradVarName("Eigenvectors")},
return KernelSignature(
"eigh_grad",
{"Eigenvalues", "Eigenvectors", "Eigenvalues@GRAD", "Eigenvectors@GRAD"},
{},
{GradVarName("X")});
{"X@GRAD"});
}
} // namespace phi
......
......@@ -106,10 +106,8 @@ KernelSignature ElementwisePowOpArgumentMapping(
KernelSignature ElementwiseAddGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("add_grad",
{"X", "Y", GradVarName("Out")},
{"axis"},
{GradVarName("X"), GradVarName("Y")});
return KernelSignature(
"add_grad", {"X", "Y", "Out@GRAD"}, {"axis"}, {"X@GRAD", "Y@GRAD"});
}
KernelSignature ElementwiseAddDoubleGradOpArgumentMapping(
......@@ -128,10 +126,8 @@ KernelSignature ElementwiseAddTripleGradOpArgumentMapping(
KernelSignature ElementwiseSubGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("subtract_grad",
{"X", "Y", GradVarName("Out")},
{"axis"},
{GradVarName("X"), GradVarName("Y")});
return KernelSignature(
"subtract_grad", {"X", "Y", "Out@GRAD"}, {"axis"}, {"X@GRAD", "Y@GRAD"});
}
KernelSignature ElementwiseSubDoubleGradOpArgumentMapping(
......@@ -143,17 +139,15 @@ KernelSignature ElementwiseSubDoubleGradOpArgumentMapping(
KernelSignature ElementwiseDivGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("divide_grad",
{"X", "Y", "Out", GradVarName("Out")},
{"X", "Y", "Out", "Out@GRAD"},
{"axis"},
{GradVarName("X"), GradVarName("Y")});
{"X@GRAD", "Y@GRAD"});
}
KernelSignature ElementwiseFMinGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("fmin_grad",
{"X", "Y", GradVarName("Out")},
{"axis"},
{GradVarName("X"), GradVarName("Y")});
return KernelSignature(
"fmin_grad", {"X", "Y", "Out@GRAD"}, {"axis"}, {"X@GRAD", "Y@GRAD"});
}
KernelSignature ElementwiseDivDoubleGradOpArgumentMapping(
......@@ -161,15 +155,13 @@ KernelSignature ElementwiseDivDoubleGradOpArgumentMapping(
return KernelSignature("divide_double_grad",
{"Y", "Out", "DX", "DDX", "DDY"},
{"axis"},
{GradVarName("Y"), "DOut", "DDOut"});
{"Y@GRAD", "DOut", "DDOut"});
}
KernelSignature ElementwiseMulGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("multiply_grad",
{"X", "Y", GradVarName("Out")},
{"axis"},
{GradVarName("X"), GradVarName("Y")});
return KernelSignature(
"multiply_grad", {"X", "Y", "Out@GRAD"}, {"axis"}, {"X@GRAD", "Y@GRAD"});
}
KernelSignature ElementwiseFMaxOpArgumentMapping(
......@@ -184,10 +176,8 @@ KernelSignature ElementwiseFMinOpArgumentMapping(
KernelSignature ElementwiseFMaxGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("fmax_grad",
{"X", "Y", GradVarName("Out")},
{"axis"},
{GradVarName("X"), GradVarName("Y")});
return KernelSignature(
"fmax_grad", {"X", "Y", "Out@GRAD"}, {"axis"}, {"X@GRAD", "Y@GRAD"});
}
KernelSignature ElementwiseMulDoubleGradOpArgumentMapping(
......@@ -195,7 +185,7 @@ KernelSignature ElementwiseMulDoubleGradOpArgumentMapping(
return KernelSignature("multiply_double_grad",
{"X", "Y", "DOut", "DDX", "DDY"},
{"axis"},
{GradVarName("X"), GradVarName("Y"), "DDOut"});
{"X@GRAD", "Y@GRAD", "DDOut"});
}
KernelSignature ElementwiseMulTripleGradOpArgumentMapping(
......@@ -209,25 +199,21 @@ KernelSignature ElementwiseMulTripleGradOpArgumentMapping(
KernelSignature ElementwiseMaxGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("maximum_grad",
{"X", "Y", GradVarName("Out")},
{"axis"},
{GradVarName("X"), GradVarName("Y")});
return KernelSignature(
"maximum_grad", {"X", "Y", "Out@GRAD"}, {"axis"}, {"X@GRAD", "Y@GRAD"});
}
KernelSignature ElementwiseMinGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("minimum_grad",
{"X", "Y", GradVarName("Out")},
{"axis"},
{GradVarName("X"), GradVarName("Y")});
return KernelSignature(
"minimum_grad", {"X", "Y", "Out@GRAD"}, {"axis"}, {"X@GRAD", "Y@GRAD"});
}
KernelSignature ElementwisePowGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("elementwise_pow_grad",
{"X", "Y", GradVarName("Out")},
{"X", "Y", "Out@GRAD"},
{"axis"},
{GradVarName("X"), GradVarName("Y")});
{"X@GRAD", "Y@GRAD"});
}
} // namespace phi
......
......@@ -30,26 +30,26 @@ KernelSignature EmbeddingGradOpArgumentMapping(
if (ctx.IsDenseTensorInput("W")) {
if ((paddle::any_cast<bool>(ctx.Attr("is_sparse"))) == true) {
return KernelSignature("embedding_sparse_grad",
{"Ids", "W", GradVarName("Out")},
{"Ids", "W", "Out@GRAD"},
{"padding_idx"},
{GradVarName("W")});
{"W@GRAD"});
} else {
return KernelSignature("embedding_grad",
{"Ids", "W", GradVarName("Out")},
{"Ids", "W", "Out@GRAD"},
{"padding_idx"},
{GradVarName("W")});
{"W@GRAD"});
}
} else {
if ((paddle::any_cast<bool>(ctx.Attr("is_sparse"))) == true) {
return KernelSignature("sparse_weight_embedding_sparse_grad",
{"Ids", "W", GradVarName("Out")},
{"Ids", "W", "Out@GRAD"},
{"padding_idx"},
{GradVarName("W")});
{"W@GRAD"});
} else {
return KernelSignature("sparse_weight_embedding_grad",
{"Ids", "W", GradVarName("Out")},
{"Ids", "W", "Out@GRAD"},
{"padding_idx"},
{GradVarName("W")});
{"W@GRAD"});
}
}
}
......
......@@ -17,8 +17,7 @@
namespace phi {
KernelSignature ErfGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature(
"erf_grad", {"X", GradVarName("Out")}, {}, {GradVarName("X")});
return KernelSignature("erf_grad", {"X", "Out@GRAD"}, {}, {"X@GRAD"});
}
} // namespace phi
......
......@@ -17,8 +17,7 @@
namespace phi {
KernelSignature ErfinvGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature(
"erfinv_grad", {"Out", GradVarName("Out")}, {}, {GradVarName("X")});
return KernelSignature("erfinv_grad", {"Out", "Out@GRAD"}, {}, {"X@GRAD"});
}
} // namespace phi
......
......@@ -22,10 +22,8 @@ KernelSignature ExpandAsOpArgumentMapping(const ArgumentMappingContext& ctx) {
KernelSignature ExpandAsGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("expand_as_grad",
{"X", GradVarName("Out")},
{"target_shape"},
{GradVarName("X")});
return KernelSignature(
"expand_as_grad", {"X", "Out@GRAD"}, {"target_shape"}, {"X@GRAD"});
}
} // namespace phi
......
......@@ -28,20 +28,14 @@ KernelSignature ExpandOpArgumentMapping(const ArgumentMappingContext& ctx) {
KernelSignature ExpandGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
if (ctx.HasInput("Shape")) {
return KernelSignature("expand_grad",
{"X", GradVarName("Out")},
{"Shape"},
{GradVarName("X")});
return KernelSignature(
"expand_grad", {"X", "Out@GRAD"}, {"Shape"}, {"X@GRAD"});
} else if (ctx.InputSize("expand_shapes_tensor") > 0) {
return KernelSignature("expand_grad",
{"X", GradVarName("Out")},
{"expand_shapes_tensor"},
{GradVarName("X")});
return KernelSignature(
"expand_grad", {"X", "Out@GRAD"}, {"expand_shapes_tensor"}, {"X@GRAD"});
} else {
return KernelSignature("expand_grad",
{"X", GradVarName("Out")},
{"shape"},
{GradVarName("X")});
return KernelSignature(
"expand_grad", {"X", "Out@GRAD"}, {"shape"}, {"X@GRAD"});
}
}
......
......@@ -31,7 +31,7 @@ KernelSignature FlattenOpArgumentMapping(const ArgumentMappingContext& ctx) {
KernelSignature FlattenGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature(
"flatten_grad", {"XShape", GradVarName("Out")}, {}, {GradVarName("X")});
"flatten_grad", {"XShape", "Out@GRAD"}, {}, {"X@GRAD"});
}
} // namespace phi
......
......@@ -25,9 +25,9 @@ KernelSignature FrobeniusNormOpArgumentMapping(
KernelSignature FrobeniusNormGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("frobenius_norm_grad",
{"X", "Out", GradVarName("Out")},
{"X", "Out", "Out@GRAD"},
{"dim", "keep_dim", "reduce_all"},
{GradVarName("X")});
{"X@GRAD"});
}
} // namespace phi
......
......@@ -17,25 +17,23 @@
namespace phi {
KernelSignature GatherNdGradArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("gather_nd_grad",
{"X", "Index", GradVarName("Out")},
{},
{GradVarName("X")});
return KernelSignature(
"gather_nd_grad", {"X", "Index", "Out@GRAD"}, {}, {"X@GRAD"});
}
KernelSignature ScatterGradArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("scatter_grad",
{"Ids", "Updates", GradVarName("Out")},
{"Ids", "Updates", "Out@GRAD"},
{"overwrite"},
{GradVarName("X"), GradVarName("Updates")});
{"X@GRAD", "Updates@GRAD"});
}
KernelSignature ScatterNdAddGradArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("scatter_nd_add_grad",
{"Index", "Updates", GradVarName("Out")},
{"Index", "Updates", "Out@GRAD"},
{},
{GradVarName("X"), GradVarName("Updates")});
{"X@GRAD", "Updates@GRAD"});
}
} // namespace phi
......
......@@ -27,14 +27,14 @@ KernelSignature GatherOpArgumentMapping(const ArgumentMappingContext& ctx) {
KernelSignature GatherGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
if (ctx.HasInput("Axis")) {
return KernelSignature("gather_grad",
{"X", "Index", GradVarName("Out")},
{"X", "Index", "Out@GRAD"},
{"Axis", "overwrite"},
{GradVarName("X")});
{"X@GRAD"});
} else {
return KernelSignature("gather_grad",
{"X", "Index", GradVarName("Out")},
{"X", "Index", "Out@GRAD"},
{"axis", "overwrite"},
{GradVarName("X")});
{"X@GRAD"});
}
}
......
......@@ -21,10 +21,8 @@ KernelSignature GeluOpArgumentMapping(const ArgumentMappingContext& ctx) {
}
KernelSignature GeluGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("gelu_grad",
{"X", GradVarName("Out")},
{"approximate"},
{GradVarName("X")});
return KernelSignature(
"gelu_grad", {"X", "Out@GRAD"}, {"approximate"}, {"X@GRAD"});
}
} // namespace phi
......
......@@ -28,9 +28,9 @@ KernelSignature GraphSendRecvGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature(
"graph_send_recv_grad",
{"X", "Src_index", "Dst_index", "Out", "Dst_count", GradVarName("Out")},
{"X", "Src_index", "Dst_index", "Out", "Dst_count", "Out@GRAD"},
{"pool_type"},
{GradVarName("X")});
{"X@GRAD"});
}
} // namespace phi
......
......@@ -27,9 +27,9 @@ KernelSignature GridSamplerOpArgumentMapping(
KernelSignature GridSamplerGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("grid_sample_grad",
{"X", "Grid", GradVarName("Output")},
{"X", "Grid", "Output@GRAD"},
{"mode", "padding_mode", "align_corners"},
{GradVarName("X"), GradVarName("Grid")});
{"X@GRAD", "Grid@GRAD"});
}
} // namespace phi
......
......@@ -18,10 +18,8 @@ namespace phi {
KernelSignature GumbelSoftmaxGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("gumbel_softmax_grad",
{"Out", GradVarName("Out")},
{"axis"},
{GradVarName("X")});
return KernelSignature(
"gumbel_softmax_grad", {"Out", "Out@GRAD"}, {"axis"}, {"X@GRAD"});
}
} // namespace phi
......
......@@ -32,9 +32,8 @@ KernelSignature HierarchicalSigmoidOpArgumentMapping(
KernelSignature HierarchicalSigmoidGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
if (ctx.IsDenseTensorOutput(GradVarName("W"))) {
return KernelSignature(
"hierarchical_sigmoid_grad",
if (ctx.IsDenseTensorOutput("W@GRAD")) {
return KernelSignature("hierarchical_sigmoid_grad",
{"X",
"W",
"Label",
......@@ -42,7 +41,7 @@ KernelSignature HierarchicalSigmoidGradOpArgumentMapping(
"PathCode",
"Bias",
"PreOut",
GradVarName("Out")},
"Out@GRAD"},
{"num_classes",
"remote_prefetch",
"trainer_id",
......@@ -50,10 +49,9 @@ KernelSignature HierarchicalSigmoidGradOpArgumentMapping(
"epmap",
"table_names",
"is_sparse"},
{GradVarName("X"), GradVarName("W"), GradVarName("Bias")});
} else if (ctx.IsSelectedRowsOutput(GradVarName("W"))) {
return KernelSignature(
"hierarchical_sigmoid_grad_sr",
{"X@GRAD", "W@GRAD", "Bias@GRAD"});
} else if (ctx.IsSelectedRowsOutput("W@GRAD")) {
return KernelSignature("hierarchical_sigmoid_grad_sr",
{"X",
"W",
"Label",
......@@ -61,7 +59,7 @@ KernelSignature HierarchicalSigmoidGradOpArgumentMapping(
"PathCode",
"Bias",
"PreOut",
GradVarName("Out")},
"Out@GRAD"},
{"num_classes",
"remote_prefetch",
"trainer_id",
......@@ -69,7 +67,7 @@ KernelSignature HierarchicalSigmoidGradOpArgumentMapping(
"epmap",
"table_names",
"is_sparse"},
{GradVarName("X"), GradVarName("W"), GradVarName("Bias")});
{"X@GRAD", "W@GRAD", "Bias@GRAD"});
} else {
return KernelSignature("unregistered", {}, {}, {});
}
......
......@@ -24,9 +24,9 @@ KernelSignature HuberLossOpArgumentMapping(const ArgumentMappingContext& ctx) {
KernelSignature HuberLossGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("huber_loss_grad",
{"Residual", GradVarName("Out")},
{"Residual", "Out@GRAD"},
{"delta"},
{GradVarName("X"), GradVarName("Y")});
{"X@GRAD", "Y@GRAD"});
}
} // namespace phi
......
......@@ -18,10 +18,8 @@ namespace phi {
KernelSignature IndexSampleGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("index_sample_grad",
{"X", "Index", GradVarName("Out")},
{},
{GradVarName("X")});
return KernelSignature(
"index_sample_grad", {"X", "Index", "Out@GRAD"}, {}, {"X@GRAD"});
}
} // namespace phi
......
......@@ -18,10 +18,8 @@ namespace phi {
KernelSignature IndexSelectGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("index_select_grad",
{"X", "Index", GradVarName("Out")},
{"dim"},
{GradVarName("X")});
return KernelSignature(
"index_select_grad", {"X", "Index", "Out@GRAD"}, {"dim"}, {"X@GRAD"});
}
} // namespace phi
......
......@@ -92,9 +92,8 @@ KernelSignature BicubicInterpOpArgumentMapping(
KernelSignature BilinearInterpGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature(
"bilinear_interp_v2_grad",
{"X", "OutSize", "SizeTensor", "Scale", GradVarName("Out")},
return KernelSignature("bilinear_interp_v2_grad",
{"X", "OutSize", "SizeTensor", "Scale", "Out@GRAD"},
{"data_layout",
"out_d",
"out_h",
......@@ -103,14 +102,13 @@ KernelSignature BilinearInterpGradOpArgumentMapping(
"interp_method",
"align_corners",
"align_mode"},
{GradVarName("X")});
{"X@GRAD"});
}
KernelSignature NearestInterpGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature(
"nearest_interp_v2_grad",
{"X", "OutSize", "SizeTensor", "Scale", GradVarName("Out")},
return KernelSignature("nearest_interp_v2_grad",
{"X", "OutSize", "SizeTensor", "Scale", "Out@GRAD"},
{"data_layout",
"out_d",
"out_h",
......@@ -119,13 +117,12 @@ KernelSignature NearestInterpGradOpArgumentMapping(
"interp_method",
"align_corners",
"align_mode"},
{GradVarName("X")});
{"X@GRAD"});
}
KernelSignature TrilinearInterpGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature(
"trilinear_interp_v2_grad",
{"X", "OutSize", "SizeTensor", "Scale", GradVarName("Out")},
return KernelSignature("trilinear_interp_v2_grad",
{"X", "OutSize", "SizeTensor", "Scale", "Out@GRAD"},
{"data_layout",
"out_d",
"out_h",
......@@ -134,14 +131,13 @@ KernelSignature TrilinearInterpGradOpArgumentMapping(
"interp_method",
"align_corners",
"align_mode"},
{GradVarName("X")});
{"X@GRAD"});
}
KernelSignature LinearInterpGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature(
"linear_interp_v2_grad",
{"X", "OutSize", "SizeTensor", "Scale", GradVarName("Out")},
return KernelSignature("linear_interp_v2_grad",
{"X", "OutSize", "SizeTensor", "Scale", "Out@GRAD"},
{"data_layout",
"out_d",
"out_h",
......@@ -150,14 +146,13 @@ KernelSignature LinearInterpGradOpArgumentMapping(
"interp_method",
"align_corners",
"align_mode"},
{GradVarName("X")});
{"X@GRAD"});
}
KernelSignature BicubicInterpGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature(
"bicubic_interp_v2_grad",
{"X", "OutSize", "SizeTensor", "Scale", GradVarName("Out")},
return KernelSignature("bicubic_interp_v2_grad",
{"X", "OutSize", "SizeTensor", "Scale", "Out@GRAD"},
{"data_layout",
"out_d",
"out_h",
......@@ -166,7 +161,7 @@ KernelSignature BicubicInterpGradOpArgumentMapping(
"interp_method",
"align_corners",
"align_mode"},
{GradVarName("X")});
{"X@GRAD"});
}
} // namespace phi
......
......@@ -20,9 +20,9 @@ namespace phi {
KernelSignature KLDivLossGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("kldiv_loss_grad",
{"X", "Target", GradVarName("Loss")},
{"X", "Target", "Loss@GRAD"},
{"reduction"},
{GradVarName("X")});
{"X@GRAD"});
}
} // namespace phi
......
......@@ -17,10 +17,8 @@
namespace phi {
KernelSignature KronGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("kron_grad",
{"X", "Y", GradVarName("Out")},
{},
{GradVarName("X"), GradVarName("Y")});
return KernelSignature(
"kron_grad", {"X", "Y", "Out@GRAD"}, {}, {"X@GRAD", "Y@GRAD"});
}
} // namespace phi
......
......@@ -20,9 +20,9 @@ namespace phi {
KernelSignature KthvalueGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("kthvalue_grad",
{"X", "Indices", GradVarName("Out")},
{"X", "Indices", "Out@GRAD"},
{"k", "axis", "keepdim"},
{GradVarName("X")});
{"X@GRAD"});
}
} // namespace phi
......
......@@ -24,10 +24,8 @@ KernelSignature LabelSmoothOpArgumentMapping(
KernelSignature LabelSmoothGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("label_smooth_grad",
{GradVarName("Out")},
{"epsilon"},
{GradVarName("X")});
return KernelSignature(
"label_smooth_grad", {"Out@GRAD"}, {"epsilon"}, {"X@GRAD"});
}
} // namespace phi
......
......@@ -25,11 +25,10 @@ KernelSignature LayerNormOpArgumentMapping(const ArgumentMappingContext& ctx) {
KernelSignature LayerNormGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature(
"layer_norm_grad",
{"X", "Scale", "Bias", "Mean", "Variance", GradVarName("Y")},
return KernelSignature("layer_norm_grad",
{"X", "Scale", "Bias", "Mean", "Variance", "Y@GRAD"},
{"epsilon", "begin_norm_axis", "is_test"},
{GradVarName("X"), GradVarName("Scale"), GradVarName("Bias")});
{"X@GRAD", "Scale@GRAD", "Bias@GRAD"});
}
} // namespace phi
......
......@@ -22,9 +22,9 @@ KernelSignature LerpOpArgumentMapping(const ArgumentMappingContext& ctx) {
KernelSignature LerpGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("lerp_grad",
{"X", "Y", "Weight", "Out", GradVarName("Out")},
{"X", "Y", "Weight", "Out", "Out@GRAD"},
{},
{GradVarName("X"), GradVarName("Y")});
{"X@GRAD", "Y@GRAD"});
}
} // namespace phi
......
......@@ -17,8 +17,7 @@
namespace phi {
KernelSignature LgammaGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature(
"lgamma_grad", {"X", GradVarName("Out")}, {}, {GradVarName("X")});
return KernelSignature("lgamma_grad", {"X", "Out@GRAD"}, {}, {"X@GRAD"});
}
} // namespace phi
......
......@@ -19,9 +19,9 @@ namespace phi {
KernelSignature LogLossGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("log_loss_grad",
{"Predicted", "Labels", GradVarName("Loss")},
{"Predicted", "Labels", "Loss@GRAD"},
{"epsilon"},
{GradVarName("Predicted")});
{"Predicted@GRAD"});
}
} // namespace phi
......
......@@ -18,10 +18,8 @@ namespace phi {
KernelSignature LogSoftmaxGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("log_softmax_grad",
{"Out", GradVarName("Out")},
{"axis"},
{GradVarName("X")});
return KernelSignature(
"log_softmax_grad", {"Out", "Out@GRAD"}, {"axis"}, {"X@GRAD"});
}
} // namespace phi
......
......@@ -19,9 +19,9 @@ namespace phi {
KernelSignature LogsumexpGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("logsumexp_grad",
{"X", "Out", GradVarName("Out")},
{"X", "Out", "Out@GRAD"},
{"axis", "keepdim", "reduce_all"},
{GradVarName("X")});
{"X@GRAD"});
}
} // namespace phi
......
......@@ -23,10 +23,8 @@ KernelSignature MaskedSelectOpArgumentMapping(
KernelSignature MaskedSelectGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("masked_select_grad",
{"X", "Mask", GradVarName("Y")},
{},
{GradVarName("X")});
return KernelSignature(
"masked_select_grad", {"X", "Mask", "Y@GRAD"}, {}, {"X@GRAD"});
}
} // namespace phi
......
......@@ -19,14 +19,14 @@ namespace phi {
KernelSignature MatmulGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
if (ctx.HasAttr("use_addto")) {
return KernelSignature("addto_matmul_grad",
{"X", "Y", GradVarName("Out")},
{"X", "Y", "Out@GRAD"},
{"trans_x", "trans_y", "use_addto"},
{GradVarName("X"), GradVarName("Y")});
{"X@GRAD", "Y@GRAD"});
} else {
return KernelSignature("matmul_grad",
{"X", "Y", GradVarName("Out")},
{"X", "Y", "Out@GRAD"},
{"trans_x", "trans_y"},
{GradVarName("X"), GradVarName("Y")});
{"X@GRAD", "Y@GRAD"});
}
}
......
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册