未验证 提交 ec1d2a16 编写于 作者: Chen Weihang 提交者: GitHub

[Cherry-pick] Optimize dygraph scheduling performance (#42010)

* [Phi] Support setting size of vector<Tensor> for out in yaml (#41576)

* support setting vector out size in yaml

* support setting size of vector<tensor> for out in yaml

* resolve conflict
Co-authored-by: zyfncg <zhangyunfei07@baidu.com>
上级 e4cb897e
......@@ -235,7 +235,7 @@ def ParseYamlReturns(string):
returns = [x.strip() for x in string.strip().split(",")]
for i in range(len(returns)):
ret = returns[i]
ret = returns[i].split("{")[0].strip()
ret_name = ""
if "(" in ret and ")" in ret:
......
......@@ -308,10 +308,100 @@ void CompatMetaTensor::share_meta(const MetaTensor& meta_tensor) {
share_lod(meta_tensor);
}
phi::InferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
const std::string& op_type) {
// Appends one input tensor, recording the half-open [idx, idx + 1) slot it
// occupies so single- and multi-tensor arguments share one range table.
void CompatInferMetaContext::EmplaceBackInput(CompatMetaTensor input) {
  const int idx = static_cast<int>(compat_inputs_.size());
  input_range_.emplace_back(idx, idx + 1);
  compat_inputs_.emplace_back(std::move(input));
}
// Appends one output tensor, recording the half-open [idx, idx + 1) slot it
// occupies in the output range table.
void CompatInferMetaContext::EmplaceBackOutput(CompatMetaTensor output) {
  const int idx = static_cast<int>(compat_outputs_.size());
  output_range_.emplace_back(idx, idx + 1);
  compat_outputs_.emplace_back(std::move(output));
}
// Appends a group of input tensors as a single argument, recording the
// contiguous [start, start + n) range they occupy. Elements are moved in,
// never copied.
void CompatInferMetaContext::EmplaceBackInputs(
    paddle::SmallVector<CompatMetaTensor, phi::kInputSmallVectorSize> inputs) {
  const int start = static_cast<int>(compat_inputs_.size());
  input_range_.emplace_back(start, start + static_cast<int>(inputs.size()));
  for (auto& tensor : inputs) {
    compat_inputs_.emplace_back(std::move(tensor));
  }
}
// Appends a group of output tensors as a single argument, recording the
// contiguous [start, start + n) range they occupy. Elements are moved in,
// never copied.
void CompatInferMetaContext::EmplaceBackOutputs(
    paddle::SmallVector<CompatMetaTensor, phi::kOutputSmallVectorSize>
        outputs) {
  const int start = static_cast<int>(compat_outputs_.size());
  output_range_.emplace_back(start, start + static_cast<int>(outputs.size()));
  for (auto& tensor : outputs) {
    compat_outputs_.emplace_back(std::move(tensor));
  }
}
// Returns the input tensor stored at `idx` (bounds-checked; the tensor may
// be uninitialized — callers use initialized() to distinguish).
const phi::MetaTensor& CompatInferMetaContext::InputAt(size_t idx) const {
  const auto& tensor = compat_inputs_.at(idx);
  return tensor;
}
// Returns the input at `idx` wrapped in an optional; an uninitialized slot
// (how fluid marks an absent optional input) maps to paddle::none.
paddle::optional<const phi::MetaTensor&>
CompatInferMetaContext::OptionalInputAt(size_t idx) const {
  const auto& tensor = compat_inputs_.at(idx);
  if (!tensor.initialized()) {
    return paddle::optional<const phi::MetaTensor&>{paddle::none};
  }
  return paddle::optional<const phi::MetaTensor&>{tensor};
}
std::vector<const phi::MetaTensor*> CompatInferMetaContext::InputsBetween(
size_t start, size_t end) const {
std::vector<const phi::MetaTensor*> result;
result.reserve(end - start);
for (size_t i = start; i < end; ++i) {
auto& in = compat_inputs_.at(i);
result.emplace_back(in.initialized() ? &in : nullptr);
}
return result;
}
// Returns the inputs in [start, end) wrapped in an optional. The whole group
// is treated as absent (paddle::none) when its first element is
// uninitialized, which is how fluid marks an optional multi-tensor input
// that was not provided.
paddle::optional<const std::vector<const phi::MetaTensor*>>
CompatInferMetaContext::OptionalInputsBetween(size_t start, size_t end) const {
  const auto& first = compat_inputs_.at(start);
  if (!first.initialized()) {
    return paddle::optional<const std::vector<const phi::MetaTensor*>>(
        paddle::none);
  }
  // Delegate to InputsBetween instead of duplicating its gather loop, so the
  // nullptr-for-uninitialized convention is defined in exactly one place.
  return paddle::optional<const std::vector<const phi::MetaTensor*>>(
      InputsBetween(start, end));
}
// Returns a mutable pointer to the output at `idx`, or nullptr when the
// slot holds an uninitialized (absent) tensor.
phi::MetaTensor* CompatInferMetaContext::MutableOutputAt(size_t idx) {
  auto& tensor = compat_outputs_.at(idx);
  if (!tensor.initialized()) {
    return nullptr;
  }
  return &tensor;
}
std::vector<phi::MetaTensor*> CompatInferMetaContext::MutableOutputBetween(
size_t start, size_t end) {
std::vector<phi::MetaTensor*> result;
result.reserve(end - start);
for (size_t i = start; i < end; ++i) {
auto& out = compat_outputs_.at(i);
result.emplace_back(out.initialized() ? &out : nullptr);
}
return result;
}
CompatInferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
const std::string& op_type) {
// 1. get kernel args
InitDefaultKernelSignatureMap();
auto arg_map_fn = phi::OpUtilsMap::Instance().GetArgumentMappingFn(op_type);
PADDLE_ENFORCE_NOT_NULL(
arg_map_fn, platform::errors::NotFound(
......@@ -321,52 +411,47 @@ phi::InferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
VLOG(3) << "BuildInferMetaContext: op kernel signature - " << signature;
// 2. build infermeta context
phi::InferMetaContext infer_meta_context(
CompatInferMetaContext infer_meta_context(
{ctx->IsRuntime(), ctx->IsRunMKLDNNKernel()});
auto& input_names = std::get<0>(signature.args);
auto& attr_names = std::get<1>(signature.args);
auto& output_names = std::get<2>(signature.args);
auto kernels_map =
phi::KernelFactory::Instance().SelectKernelMap(signature.name);
if (kernels_map.size() == 0) {
PADDLE_THROW(
platform::errors::Unimplemented("Not find `%s` kernels when construct "
"InferMetaContext.",
signature.name));
}
auto attr_defs = kernels_map.cbegin()->second.args_def().attribute_defs();
const auto& args_def =
phi::KernelFactory::Instance().GetFirstKernelArgsDef(signature.name);
const auto& attr_defs = args_def.attribute_defs();
// TODO(chenweihang): support multiple inputs and outputs later
phi::InferMetaContext infer_mete_context;
for (auto& in_name : input_names) {
if (ctx->HasInputs(in_name)) {
auto input_var = ctx->GetInputVarPtrs(in_name);
auto input_var = std::move(ctx->GetInputVarPtrs(in_name));
if (input_var.size() == 1) {
infer_meta_context.EmplaceBackInput(
std::make_shared<CompatMetaTensor>(input_var[0], ctx->IsRuntime()));
std::move(CompatMetaTensor(input_var[0], ctx->IsRuntime())));
} else {
paddle::SmallVector<std::shared_ptr<phi::MetaTensor>> inputs;
inputs.reserve(input_var.size());
paddle::SmallVector<CompatMetaTensor, phi::kInputSmallVectorSize>
inputs;
for (const auto& in : input_var) {
inputs.push_back(
std::make_shared<CompatMetaTensor>(in, ctx->IsRuntime()));
inputs.emplace_back(
std::move(CompatMetaTensor(in, ctx->IsRuntime())));
}
infer_meta_context.EmplaceBackInputs(std::move(inputs));
}
} else {
infer_meta_context.EmplaceBackInput({nullptr});
infer_meta_context.EmplaceBackInput(
std::move(CompatMetaTensor(ctx->IsRuntime())));
}
}
VLOG(6) << "BuildInferMetaContext: Done inputs";
auto attr_reader = ctx->Attrs();
for (size_t i = 0; i < attr_names.size(); ++i) {
auto attr_name = attr_names[i];
auto& attr_name = attr_names[i];
if (attr_defs[i].type_index == std::type_index(typeid(phi::IntArray))) {
// When attr is a vector_tensor or tensor, transform it to IntArray
if (ctx->HasInputs(attr_name) || ctx->HasInput(attr_name)) {
const auto& infershape_inputs = ctx->GetInputVarPtrs(attr_name);
auto infershape_inputs = std::move(ctx->GetInputVarPtrs(attr_name));
if (ctx->IsRuntime()) {
// If is in runtime, we will get tensor's value for IntArray
// and push it into attrs
......@@ -456,7 +541,7 @@ phi::InferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
attr_name));
}
} else if (ctx->HasInput(attr_name)) {
const auto& infershape_input = ctx->GetInputVarPtrs(attr_name);
auto infershape_input = std::move(ctx->GetInputVarPtrs(attr_name));
if (infershape_input.size() == 1) {
if (ctx->IsRuntime()) {
Variable* var = BOOST_GET_CONST(Variable*, infershape_input[0]);
......@@ -581,7 +666,7 @@ phi::InferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
// convert from data
if (attr_defs[i].type_index == std::type_index(typeid(int32_t))) {
if (ctx->IsRuntime()) {
const auto& infershape_inputs = ctx->GetInputVarPtrs(attr_name);
auto infershape_inputs = std::move(ctx->GetInputVarPtrs(attr_name));
auto var_temp = BOOST_GET_CONST(Variable*, infershape_inputs[i]);
auto val = experimental::MakePhiScalarFromVar(*var_temp);
int32_t val_int = val.template to<int32_t>();
......@@ -596,36 +681,41 @@ phi::InferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
}
}
VLOG(6) << "BuildInferMetaContext: Done attrs";
for (auto& out_name : output_names) {
if (ctx->HasOutputs(out_name, true)) {
auto output_var = ctx->GetOutputVarPtrs(out_name);
auto output_var = std::move(ctx->GetOutputVarPtrs(out_name));
if (output_var.size() == 1) {
infer_meta_context.EmplaceBackOutput(std::make_shared<CompatMetaTensor>(
output_var[0], ctx->IsRuntime()));
infer_meta_context.EmplaceBackOutput(
std::move(CompatMetaTensor(output_var[0], ctx->IsRuntime())));
} else {
paddle::SmallVector<std::shared_ptr<phi::MetaTensor>> outputs;
outputs.reserve(output_var.size());
paddle::SmallVector<CompatMetaTensor, phi::kOutputSmallVectorSize>
outputs;
for (const auto& out : output_var) {
if (ctx->IsRuntime()) {
if (BOOST_GET_CONST(Variable*, out)) {
outputs.emplace_back(
std::make_shared<CompatMetaTensor>(out, ctx->IsRuntime()));
std::move(CompatMetaTensor(out, ctx->IsRuntime())));
continue;
}
} else if (BOOST_GET_CONST(VarDesc*, out)) {
outputs.emplace_back(
std::make_shared<CompatMetaTensor>(out, ctx->IsRuntime()));
std::move(CompatMetaTensor(out, ctx->IsRuntime())));
continue;
}
outputs.emplace_back(nullptr);
outputs.emplace_back(std::move(CompatMetaTensor(ctx->IsRuntime())));
}
infer_meta_context.EmplaceBackOutputs(std::move(outputs));
}
} else {
infer_meta_context.EmplaceBackOutput({nullptr});
infer_meta_context.EmplaceBackOutput(
std::move(CompatMetaTensor(ctx->IsRuntime())));
}
}
VLOG(6) << "BuildInferMetaContext: Done outputs";
return infer_meta_context;
}
......
......@@ -18,38 +18,24 @@ limitations under the License. */
#include "paddle/fluid/framework/op_info.h"
#include "paddle/fluid/framework/shape_inference.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/core/meta_tensor.h"
namespace phi {
class InferMetaContext;
} // namespace phi
namespace paddle {
namespace framework {
phi::InferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
const std::string& op_type);
#define DECLARE_INFER_SHAPE_FUNCTOR(op_type, functor_name, fn) \
struct functor_name : public paddle::framework::InferShapeBase { \
void operator()( \
paddle::framework::InferShapeContext* ctx) const override { \
auto infer_meta_context = \
paddle::framework::BuildInferMetaContext(ctx, #op_type); \
fn(&infer_meta_context); \
} \
}
// TODO(chenweihang): Support TensorArray later
class CompatMetaTensor : public phi::MetaTensor {
public:
explicit CompatMetaTensor(bool is_runtime)
: is_runtime_(is_runtime), initialized_(false) {}
CompatMetaTensor(InferShapeVarPtr var, bool is_runtime)
: var_(std::move(var)), is_runtime_(is_runtime) {}
CompatMetaTensor() = default;
CompatMetaTensor(const CompatMetaTensor&) = default;
CompatMetaTensor(CompatMetaTensor&&) = default;
CompatMetaTensor& operator=(const CompatMetaTensor&) = delete;
CompatMetaTensor& operator=(CompatMetaTensor&&) = delete;
CompatMetaTensor& operator=(CompatMetaTensor&&) = default;
CompatMetaTensor(const CompatMetaTensor&) = default;
CompatMetaTensor& operator=(const CompatMetaTensor&) = default;
int64_t numel() const override;
......@@ -71,6 +57,8 @@ class CompatMetaTensor : public phi::MetaTensor {
void share_meta(const MetaTensor& meta_tensor) override;
bool initialized() const override { return initialized_; };
private:
const LoD& GetRuntimeLoD() const {
auto* var = BOOST_GET_CONST(Variable*, var_);
......@@ -95,7 +83,62 @@ class CompatMetaTensor : public phi::MetaTensor {
InferShapeVarPtr var_;
bool is_runtime_;
bool initialized_{true};
};
// Note: In order to avoid using shared_ptr to manage MetaTensor in
// InferMetaContext, inherit and implement InferMetaContext separately
// for compatibility with fluid, shared_ptr will cause significant decrease
// in scheduling performance
// InferMeta context that stores CompatMetaTensor objects by value in
// SmallVectors instead of via shared_ptr, avoiding per-tensor heap
// allocation and refcounting on the dygraph scheduling hot path.
class CompatInferMetaContext : public phi::InferMetaContext {
public:
CompatInferMetaContext() = default;
explicit CompatInferMetaContext(phi::MetaConfig config)
: phi::InferMetaContext(config) {}
// Append a single input/output tensor (moved in, recorded as a 1-element
// range in the base class's range tables).
void EmplaceBackInput(CompatMetaTensor input);
void EmplaceBackOutput(CompatMetaTensor output);
// Append a group of tensors as one argument occupying a contiguous range.
void EmplaceBackInputs(
paddle::SmallVector<CompatMetaTensor, phi::kInputSmallVectorSize> inputs);
void EmplaceBackOutputs(
paddle::SmallVector<CompatMetaTensor, phi::kOutputSmallVectorSize>
outputs);
// Accessors overriding phi::InferMetaContext; uninitialized tensors are
// reported as paddle::none / nullptr depending on the return type.
const phi::MetaTensor& InputAt(size_t idx) const override;
paddle::optional<const phi::MetaTensor&> OptionalInputAt(
size_t idx) const override;
std::vector<const phi::MetaTensor*> InputsBetween(size_t start,
size_t end) const override;
paddle::optional<const std::vector<const phi::MetaTensor*>>
OptionalInputsBetween(size_t start, size_t end) const override;
phi::MetaTensor* MutableOutputAt(size_t idx) override;
std::vector<phi::MetaTensor*> MutableOutputBetween(size_t start,
size_t end) override;
virtual ~CompatInferMetaContext() = default;
private:
// Value storage for inputs/outputs; shadows the base class's shared_ptr
// storage, which this subclass does not use.
paddle::SmallVector<CompatMetaTensor, phi::kInputSmallVectorSize>
compat_inputs_;
paddle::SmallVector<CompatMetaTensor, phi::kOutputSmallVectorSize>
compat_outputs_;
};
CompatInferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
const std::string& op_type);
// Declares an InferShapeBase functor `functor_name` for operator `op_type`:
// it builds a CompatInferMetaContext from the fluid InferShapeContext and
// forwards it to the phi InferMeta function `fn`, bridging fluid shape
// inference to phi. (Comments cannot go inside the macro body because of
// the line continuations.)
#define DECLARE_INFER_SHAPE_FUNCTOR(op_type, functor_name, fn) \
struct functor_name : public paddle::framework::InferShapeBase { \
void operator()( \
paddle::framework::InferShapeContext* ctx) const override { \
auto infer_meta_context = \
paddle::framework::BuildInferMetaContext(ctx, #op_type); \
fn(&infer_meta_context); \
} \
}
} // namespace framework
} // namespace paddle
......@@ -328,20 +328,21 @@ bool InterpretercoreInferShapeContext::IsRunMKLDNNKernel() const {
}
// TODO(paddle-dev): Can this be template?
std::vector<InferShapeVarPtr> InterpretercoreInferShapeContext::GetInputVarPtrs(
paddle::SmallVector<InferShapeVarPtr, phi::kInputSmallVectorSize>
InterpretercoreInferShapeContext::GetInputVarPtrs(
const std::string& name) const {
const std::vector<Variable*>& vars = InputVars(name);
std::vector<InferShapeVarPtr> res;
paddle::SmallVector<InferShapeVarPtr, phi::kInputSmallVectorSize> res;
res.reserve(vars.size());
res.insert(res.begin(), vars.begin(), vars.end());
return res;
}
std::vector<InferShapeVarPtr>
paddle::SmallVector<InferShapeVarPtr, phi::kOutputSmallVectorSize>
InterpretercoreInferShapeContext::GetOutputVarPtrs(
const std::string& name) const {
const std::vector<Variable*>& vars = OutputVars(name);
std::vector<InferShapeVarPtr> res;
paddle::SmallVector<InferShapeVarPtr, phi::kOutputSmallVectorSize> res;
res.reserve(vars.size());
res.insert(res.begin(), vars.begin(), vars.end());
return res;
......
......@@ -90,11 +90,11 @@ class InterpretercoreInferShapeContext : public InferShapeContext {
bool IsRunMKLDNNKernel() const override;
// TODO(paddle-dev): Can this be template?
std::vector<InferShapeVarPtr> GetInputVarPtrs(
const std::string& name) const override;
paddle::SmallVector<InferShapeVarPtr, phi::kInputSmallVectorSize>
GetInputVarPtrs(const std::string& name) const override;
std::vector<InferShapeVarPtr> GetOutputVarPtrs(
const std::string& name) const override;
paddle::SmallVector<InferShapeVarPtr, phi::kOutputSmallVectorSize>
GetOutputVarPtrs(const std::string& name) const override;
DDim GetInputDim(const std::string& name) const override;
......
......@@ -202,10 +202,10 @@ class CompileTimeInferShapeContext : public InferShapeContext {
}
}
std::vector<InferShapeVarPtr> GetInputVarPtrs(
const std::string &name) const override {
paddle::SmallVector<InferShapeVarPtr, phi::kInputSmallVectorSize>
GetInputVarPtrs(const std::string &name) const override {
const std::vector<std::string> arg_names = Inputs(name);
std::vector<InferShapeVarPtr> res;
paddle::SmallVector<InferShapeVarPtr, phi::kInputSmallVectorSize> res;
res.reserve(arg_names.size());
std::transform(arg_names.begin(), arg_names.end(), std::back_inserter(res),
[this](const std::string &name) {
......@@ -214,10 +214,10 @@ class CompileTimeInferShapeContext : public InferShapeContext {
return res;
}
std::vector<InferShapeVarPtr> GetOutputVarPtrs(
const std::string &name) const override {
paddle::SmallVector<InferShapeVarPtr, phi::kOutputSmallVectorSize>
GetOutputVarPtrs(const std::string &name) const override {
const std::vector<std::string> arg_names = Outputs(name);
std::vector<InferShapeVarPtr> res;
paddle::SmallVector<InferShapeVarPtr, phi::kOutputSmallVectorSize> res;
res.reserve(arg_names.size());
std::transform(arg_names.begin(), arg_names.end(), std::back_inserter(res),
[this](const std::string &name) {
......
......@@ -947,19 +947,19 @@ class RuntimeInferShapeContext : public InferShapeContext {
}
// TODO(paddle-dev): Can this be template?
std::vector<InferShapeVarPtr> GetInputVarPtrs(
const std::string& name) const override {
paddle::SmallVector<InferShapeVarPtr, phi::kInputSmallVectorSize>
GetInputVarPtrs(const std::string& name) const override {
const std::vector<Variable*>& vars = InputVars(name);
std::vector<InferShapeVarPtr> res;
paddle::SmallVector<InferShapeVarPtr, phi::kInputSmallVectorSize> res;
res.reserve(vars.size());
res.insert(res.begin(), vars.begin(), vars.end());
return res;
}
std::vector<InferShapeVarPtr> GetOutputVarPtrs(
const std::string& name) const override {
paddle::SmallVector<InferShapeVarPtr, phi::kOutputSmallVectorSize>
GetOutputVarPtrs(const std::string& name) const override {
const std::vector<Variable*>& vars = OutputVars(name);
std::vector<InferShapeVarPtr> res;
paddle::SmallVector<InferShapeVarPtr, phi::kOutputSmallVectorSize> res;
res.reserve(vars.size());
res.insert(res.begin(), vars.begin(), vars.end());
return res;
......@@ -1326,8 +1326,8 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
<< ", using_kernel_key:" << *kernel_type_.get();
auto try_pt_kernel_key =
TransOpKernelTypeToPhiKernelKey(*kernel_type_.get());
if (!phi::KernelFactory::Instance().IsSelectKernelValid(
pt_kernel_name, try_pt_kernel_key)) {
if (!phi::KernelFactory::Instance().HasKernel(pt_kernel_name,
try_pt_kernel_key)) {
kernel_type_->library_type_ = expected_kernel_key_library_type;
VLOG(3) << "modify XPU KP kernel in static graph: " << type_
<< " is failed " << *kernel_type_.get();
......@@ -2115,10 +2115,12 @@ OpKernelType OperatorWithKernel::GetKernelTypeForVar(
KernelSignature OperatorWithKernel::GetExpectedPhiKernelArgs(
const ExecutionContext& ctx) const {
InitDefaultKernelSignatureMap();
ExecutionArgumentMappingContext arg_mapping_ctx(ctx);
return phi::OpUtilsMap::Instance().GetArgumentMappingFn(Type())(
arg_mapping_ctx);
if (arg_map_fn_ == nullptr) {
arg_map_fn_.reset(new phi::ArgumentMappingFn(
phi::OpUtilsMap::Instance().GetArgumentMappingFn(Type())));
}
return (*arg_map_fn_)(arg_mapping_ctx);
}
Scope* OperatorWithKernel::PreparePhiData(
......
......@@ -701,6 +701,7 @@ class OperatorWithKernel : public OperatorBase {
mutable bool run_kp_kernel = false;
mutable std::unique_ptr<phi::KernelSignature> pt_kernel_signature_;
mutable std::unique_ptr<phi::Kernel> pt_kernel_;
mutable std::unique_ptr<phi::ArgumentMappingFn> arg_map_fn_;
};
extern bool OpSupportGPU(const std::string& op_type);
......
......@@ -25,6 +25,7 @@ limitations under the License. */
#include "paddle/phi/core/compat/convert_utils.h"
#include "paddle/phi/core/compat/op_utils.h"
#include "paddle/phi/core/kernel_factory.h"
#include "paddle/phi/core/type_defs.h"
namespace paddle {
namespace framework {
......@@ -40,9 +41,9 @@ class KernelArgsNameMakerByOpProto : public KernelArgsNameMaker {
~KernelArgsNameMakerByOpProto() {}
const paddle::SmallVector<std::string>& GetInputArgsNames() override;
const paddle::SmallVector<std::string>& GetOutputArgsNames() override;
const paddle::SmallVector<std::string>& GetAttrsArgsNames() override;
const paddle::SmallVector<const char*>& GetInputArgsNames() override;
const paddle::SmallVector<const char*>& GetOutputArgsNames() override;
const paddle::SmallVector<const char*>& GetAttrsArgsNames() override;
KernelSignature GetKernelSignature();
......@@ -52,9 +53,9 @@ class KernelArgsNameMakerByOpProto : public KernelArgsNameMaker {
private:
const framework::proto::OpProto* op_proto_;
paddle::SmallVector<std::string> input_names_;
paddle::SmallVector<std::string> output_names_;
paddle::SmallVector<std::string> attr_names_;
paddle::SmallVector<const char*> input_names_;
paddle::SmallVector<const char*> output_names_;
paddle::SmallVector<const char*> attr_names_;
};
OpKernelType TransPhiKernelKeyToOpKernelType(const phi::KernelKey& kernel_key) {
......@@ -102,7 +103,7 @@ phi::KernelKey FallBackToCpu(const OpKernelType& expected_kernel_key,
if (platform::is_xpu_place(expected_kernel_key.place_) ||
paddle::platform::is_in_xpu_black_list(op.Type())) {
VLOG(3) << "phi missing XPU kernel: " << op.Type()
<< ", phipected_kernel_key:" << expected_kernel_key
<< ", expected_kernel_key:" << expected_kernel_key
<< ", fallbacking to CPU one!";
return phi::KernelKey(phi::Backend::CPU, kernel_key.layout(),
kernel_key.dtype());
......@@ -111,7 +112,7 @@ phi::KernelKey FallBackToCpu(const OpKernelType& expected_kernel_key,
#ifdef PADDLE_WITH_ASCEND_CL
if (platform::is_npu_place(expected_kernel_key.place_)) {
VLOG(3) << "phi missing NPU kernel: " << op.Type()
<< ", phipected_kernel_key:" << expected_kernel_key
<< ", expected_kernel_key:" << expected_kernel_key
<< ", fallbacking to CPU one!";
return phi::KernelKey(phi::Backend::CPU, kernel_key.layout(),
kernel_key.dtype());
......@@ -120,7 +121,7 @@ phi::KernelKey FallBackToCpu(const OpKernelType& expected_kernel_key,
#ifdef PADDLE_WITH_MLU
if (platform::is_mlu_place(expected_kernel_key.place_)) {
VLOG(3) << "phi missing MLU kernel: " << op.Type()
<< ", phipected_kernel_key:" << expected_kernel_key
<< ", expected_kernel_key:" << expected_kernel_key
<< ", fallbacking to CPU one!";
return phi::KernelKey(phi::Backend::CPU, kernel_key.layout(),
kernel_key.dtype());
......@@ -129,7 +130,7 @@ phi::KernelKey FallBackToCpu(const OpKernelType& expected_kernel_key,
#ifdef PADDLE_WITH_IPU
if (platform::is_ipu_place(expected_kernel_key.place_)) {
VLOG(3) << "phi missing IPU kernel: " << op.Type()
<< ", phipected_kernel_key:" << expected_kernel_key
<< ", expected_kernel_key:" << expected_kernel_key
<< ", fallbacking to CPU one!";
return phi::KernelKey(phi::Backend::CPU, kernel_key.layout(),
kernel_key.dtype());
......@@ -139,7 +140,7 @@ phi::KernelKey FallBackToCpu(const OpKernelType& expected_kernel_key,
if (platform::is_custom_place(expected_kernel_key.place_)) {
VLOG(3) << "phi missing " << expected_kernel_key.place_.GetDeviceType()
<< " kernel: " << op.Type()
<< ", phipected_kernel_key:" << expected_kernel_key
<< ", expected_kernel_key:" << expected_kernel_key
<< ", fallbacking to CPU one!";
return phi::KernelKey(phi::Backend::CPU, kernel_key.layout(),
kernel_key.dtype());
......@@ -148,45 +149,52 @@ phi::KernelKey FallBackToCpu(const OpKernelType& expected_kernel_key,
return phi::KernelKey();
}
const paddle::SmallVector<std::string>&
const paddle::SmallVector<const char*>&
KernelArgsNameMakerByOpProto::GetInputArgsNames() {
for (int i = 0; i < op_proto_->inputs_size(); ++i) {
auto& in = op_proto_->inputs()[i];
auto& in_name = in.name();
if ((in.has_extra() && in.extra()) || (in.has_quant() && in.quant())) {
VLOG(6) << "Parse PhiKernel input: skip extra & quant input - "
<< in_name;
continue;
}
// If contains dispensable input, we should override the
// OpArgumentMapping method self in phi/ops/compat dir
if (in.has_dispensable() && in.dispensable()) {
VLOG(6) << "Parse PhiKernel input: skip dispensable input - " << in_name;
continue;
}
VLOG(6) << "Parse PhiKernel input: " << in_name;
input_names_.emplace_back(in_name);
input_names_.emplace_back(in_name.c_str());
}
if (VLOG_IS_ON(10)) {
std::ostringstream sout;
sout << "PhiKernel inputs: ";
std::copy(input_names_.begin(), input_names_.end(),
std::ostream_iterator<const char*>(sout, ", "));
VLOG(10) << sout.str();
}
return input_names_;
}
const paddle::SmallVector<std::string>&
const paddle::SmallVector<const char*>&
KernelArgsNameMakerByOpProto::GetOutputArgsNames() {
for (int i = 0; i < op_proto_->outputs_size(); ++i) {
auto& out = op_proto_->outputs()[i];
auto& out_name = out.name();
if ((out.has_extra() && out.extra()) || (out.has_quant() && out.quant())) {
VLOG(6) << "Parse PhiKernel output: skip extra & quant output - "
<< out_name;
continue;
}
VLOG(6) << "Parse PhiKernel output: " << out_name;
output_names_.emplace_back(out_name);
output_names_.emplace_back(out_name.c_str());
}
if (VLOG_IS_ON(10)) {
std::ostringstream sout;
sout << "PhiKernel outputs: ";
std::copy(output_names_.begin(), output_names_.end(),
std::ostream_iterator<const char*>(sout, ", "));
VLOG(10) << sout.str();
}
return output_names_;
}
const paddle::SmallVector<std::string>&
const paddle::SmallVector<const char*>&
KernelArgsNameMakerByOpProto::GetAttrsArgsNames() {
for (int i = 0; i < op_proto_->attrs_size(); ++i) {
auto& attr = op_proto_->attrs()[i];
......@@ -195,25 +203,26 @@ KernelArgsNameMakerByOpProto::GetAttrsArgsNames() {
attr_name == "op_role" || attr_name == "op_role_var" ||
attr_name == "op_namescope" || attr_name == "op_callstack" ||
attr_name == "op_device") {
VLOG(6) << "Parse PhiKernel attribute: skip needless attr - "
<< attr_name;
continue;
}
if ((attr.has_extra() && attr.extra()) ||
(attr.has_quant() && attr.quant())) {
VLOG(6) << "Parse PhiKernel attribute: skip extra & quant attr - "
<< attr_name;
continue;
}
VLOG(6) << "Parse PhiKernel attribute: " << attr_name;
attr_names_.emplace_back(attr_name);
attr_names_.emplace_back(attr_name.c_str());
}
if (VLOG_IS_ON(10)) {
std::ostringstream sout;
sout << "PhiKernel attributes: ";
std::copy(attr_names_.begin(), attr_names_.end(),
std::ostream_iterator<const char*>(sout, ", "));
VLOG(10) << sout.str();
}
return attr_names_;
}
KernelSignature KernelArgsNameMakerByOpProto::GetKernelSignature() {
return KernelSignature(phi::TransToPhiKernelName(op_proto_->type()),
return KernelSignature(phi::TransToPhiKernelName(op_proto_->type()).c_str(),
GetInputArgsNames(), GetAttrsArgsNames(),
GetOutputArgsNames());
}
......@@ -228,7 +237,7 @@ void InitDefaultKernelSignatureMap() {
if (phi::KernelFactory::Instance().HasCompatiblePhiKernel(op_type) &&
op_proto) {
paddle::framework::KernelArgsNameMakerByOpProto maker(op_proto);
VLOG(10) << "Register kernel signature for " << op_type;
VLOG(10) << "Register `" << op_type << "` kernel signature:";
phi::DefaultKernelSignatureMap::Instance().Insert(
op_type, std::move(maker.GetKernelSignature()));
}
......
......@@ -55,9 +55,9 @@ phi::KernelKey FallBackToCpu(const OpKernelType& expected_kernel_key,
class KernelArgsNameMaker {
public:
virtual ~KernelArgsNameMaker() {}
virtual const paddle::SmallVector<std::string>& GetInputArgsNames() = 0;
virtual const paddle::SmallVector<std::string>& GetOutputArgsNames() = 0;
virtual const paddle::SmallVector<std::string>& GetAttrsArgsNames() = 0;
virtual const paddle::SmallVector<const char*>& GetInputArgsNames() = 0;
virtual const paddle::SmallVector<const char*>& GetOutputArgsNames() = 0;
virtual const paddle::SmallVector<const char*>& GetAttrsArgsNames() = 0;
};
void InitDefaultKernelSignatureMap();
......
......@@ -21,6 +21,8 @@ limitations under the License. */
#include "paddle/fluid/framework/var_desc.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/phi/core/ddim.h"
#include "paddle/phi/core/type_defs.h"
#include "paddle/utils/small_vector.h"
namespace paddle {
namespace framework {
......@@ -106,10 +108,10 @@ class InferShapeContext {
virtual bool IsRunMKLDNNKernel() const = 0;
virtual std::vector<InferShapeVarPtr> GetInputVarPtrs(
const std::string &name) const = 0;
virtual std::vector<InferShapeVarPtr> GetOutputVarPtrs(
const std::string &name) const = 0;
virtual paddle::SmallVector<InferShapeVarPtr, phi::kInputSmallVectorSize>
GetInputVarPtrs(const std::string &name) const = 0;
virtual paddle::SmallVector<InferShapeVarPtr, phi::kOutputSmallVectorSize>
GetOutputVarPtrs(const std::string &name) const = 0;
protected:
virtual std::vector<DDim> GetRepeatedDims(const std::string &name) const = 0;
......
......@@ -235,9 +235,10 @@ class DygraphInferShapeContext : public framework::InferShapeContext {
(op_kernel_type_->data_layout_ == framework::DataLayout::kMKLDNN));
}
std::vector<framework::InferShapeVarPtr> GetInputVarPtrs(
const std::string& name) const override {
std::vector<framework::InferShapeVarPtr> res;
paddle::SmallVector<framework::InferShapeVarPtr, phi::kInputSmallVectorSize>
GetInputVarPtrs(const std::string& name) const override {
paddle::SmallVector<framework::InferShapeVarPtr, phi::kInputSmallVectorSize>
res;
auto it = var_map_in_->find(name);
PADDLE_ENFORCE_NE(
it, var_map_in_->end(),
......@@ -248,9 +249,11 @@ class DygraphInferShapeContext : public framework::InferShapeContext {
return res;
}
std::vector<framework::InferShapeVarPtr> GetOutputVarPtrs(
const std::string& name) const override {
std::vector<framework::InferShapeVarPtr> res;
paddle::SmallVector<framework::InferShapeVarPtr, phi::kOutputSmallVectorSize>
GetOutputVarPtrs(const std::string& name) const override {
paddle::SmallVector<framework::InferShapeVarPtr,
phi::kOutputSmallVectorSize>
res;
auto it = var_map_out_->find(name);
PADDLE_ENFORCE_NE(
it, var_map_out_->end(),
......
......@@ -36,6 +36,8 @@ DECLARE_bool(run_kp_kernel);
namespace paddle {
namespace imperative {
static const phi::Kernel empty_kernel;
const std::shared_ptr<VariableWrapper>& GetVariableWrapper(
const std::shared_ptr<paddle::imperative::VarBase>& var) {
return var->SharedVar();
......@@ -108,12 +110,13 @@ PreparedOp::PreparedOp(const framework::OperatorBase& op,
ctx_(ctx),
kernel_type_(kernel_type),
func_(func),
dev_ctx_(dev_ctx) {}
dev_ctx_(dev_ctx),
pt_kernel_(empty_kernel) {}
PreparedOp::PreparedOp(const framework::OperatorBase& op,
const framework::RuntimeContext& ctx,
const framework::OpKernelType& kernel_type,
const framework::KernelSignature& kernel_signature,
framework::KernelSignature&& kernel_signature,
const phi::Kernel& pt_kernel,
platform::DeviceContext* dev_ctx)
: op_(op),
......@@ -122,7 +125,7 @@ PreparedOp::PreparedOp(const framework::OperatorBase& op,
func_(nullptr),
dev_ctx_(dev_ctx),
run_phi_kernel_(true),
pt_kernel_signature_(kernel_signature),
pt_kernel_signature_(std::move(kernel_signature)),
pt_kernel_(pt_kernel) {}
template <typename VarType>
......@@ -170,7 +173,8 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins,
#endif
if (phi::KernelFactory::Instance().HasCompatiblePhiKernel(op.Type())) {
pt_kernel_signature = op.GetExpectedPhiKernelArgs(dygraph_exe_ctx);
pt_kernel_signature =
std::move(op.GetExpectedPhiKernelArgs(dygraph_exe_ctx));
VLOG(6) << pt_kernel_signature;
pt_kernel_name = pt_kernel_signature.name;
......@@ -200,8 +204,8 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins,
<< ", using_kernel_key:" << expected_kernel_key;
phi::KernelKey try_pt_kernel_key =
TransOpKernelTypeToPhiKernelKey(expected_kernel_key);
if (!phi::KernelFactory::Instance().IsSelectKernelValid(
pt_kernel_name, try_pt_kernel_key)) {
if (!phi::KernelFactory::Instance().HasKernel(pt_kernel_name,
try_pt_kernel_key)) {
expected_kernel_key.library_type_ = expected_kernel_key_library_type;
VLOG(3) << "modify XPU KP kernel: " << op.Type() << " is failed "
<< expected_kernel_key;
......@@ -211,8 +215,8 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins,
#endif
pt_kernel_key = TransOpKernelTypeToPhiKernelKey(expected_kernel_key);
auto pt_kernel = phi::KernelFactory::Instance().SelectKernel(pt_kernel_name,
pt_kernel_key);
auto& pt_kernel = phi::KernelFactory::Instance().SelectKernel(
pt_kernel_name, pt_kernel_key);
if (pt_kernel.IsValid()
#if defined(PADDLE_WITH_XPU) && !defined(PADDLE_WITH_XPU_KP)
......@@ -227,9 +231,8 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins,
dev_ctx = pool.Get(expected_kernel_key.place_);
}
// TODO(chenweihang): using CPUKernel when miss device kernel case
return PreparedOp(op, ctx, expected_kernel_key, pt_kernel_signature,
pt_kernel, dev_ctx);
return PreparedOp(op, ctx, expected_kernel_key,
std::move(pt_kernel_signature), pt_kernel, dev_ctx);
} else {
VLOG(6) << "Dynamic mode ChoosePhiKernel - kernel `" << pt_kernel_name
<< "` not found.";
......@@ -270,15 +273,16 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins,
if (phi::KernelFactory::Instance().HasCompatiblePhiKernel(op.Type())) {
auto pt_cpu_kernel_key =
FallBackToCpu(expected_kernel_key, pt_kernel_key, op);
auto pt_cpu_kernel = phi::KernelFactory::Instance().SelectKernel(
auto& pt_cpu_kernel = phi::KernelFactory::Instance().SelectKernel(
pt_kernel_name, pt_cpu_kernel_key);
if (pt_cpu_kernel.IsValid()) {
VLOG(6) << "Dynamic mode PrepareImpl - kernel name: " << pt_kernel_name
<< " | kernel key: " << pt_cpu_kernel_key
<< " | kernel: " << pt_cpu_kernel;
auto* cpu_ctx = pool.Get(paddle::platform::CPUPlace());
return PreparedOp(op, ctx, expected_kernel_key, pt_kernel_signature,
pt_cpu_kernel, cpu_ctx);
return PreparedOp(op, ctx, expected_kernel_key,
std::move(pt_kernel_signature), pt_cpu_kernel,
cpu_ctx);
}
}
}
......@@ -505,7 +509,6 @@ static void PreparedOpRunPtImpl(
#endif
}
// TODO(chenweihang): add debug flags later
if (framework::IsComplexType(kernel_type.data_type_)) {
HandleComplexGradToRealGrad<VarType>(outs);
}
......
......@@ -154,7 +154,7 @@ class PreparedOp {
PreparedOp(const framework::OperatorBase& op,
const framework::RuntimeContext& ctx,
const framework::OpKernelType& kernel_type,
const framework::KernelSignature& kernel_signature,
framework::KernelSignature&& kernel_signature,
const phi::Kernel& pt_kernel, platform::DeviceContext* dev_ctx);
static PreparedOp Prepare(const NameVarMap<VarBase>& ins,
......@@ -206,7 +206,7 @@ class PreparedOp {
bool run_phi_kernel_{false};
bool run_kp_kernel_{false};
framework::KernelSignature pt_kernel_signature_;
phi::Kernel pt_kernel_;
const phi::Kernel& pt_kernel_;
};
const inline framework::Attribute& GetAttr(
......@@ -289,7 +289,7 @@ void BuildDygraphPhiKernelContext(
}
}
auto ins_vector = it->second;
auto& ins_vector = it->second;
size_t end_idx = start_idx + ins_vector.size();
for (size_t offset = 0; offset < ins_vector.size(); ++offset) {
......@@ -587,7 +587,7 @@ void PreparePhiData(const phi::Kernel& pt_kernel,
auto& ins_vector = ins.at(input_names[i]);
for (size_t offset = 0; offset < ins_vector.size(); ++offset) {
auto var = ins_vector[offset];
auto& var = ins_vector[offset];
const auto* tensor_in = GetTensorFromVar(var->Var());
if (tensor_in && tensor_in->IsInitialized()) {
if (in_def.backend == phi::Backend::ALL_BACKEND) {
......
......@@ -226,6 +226,7 @@ bool AnalysisPredictor::PrepareScope(
status_is_cloned_ = true;
} else {
paddle::framework::InitDevices();
paddle::framework::InitDefaultKernelSignatureMap();
// TODO(wilber): we need to release memory occupied by weights.
scope_.reset(new paddle::framework::Scope());
status_is_cloned_ = false;
......
......@@ -92,6 +92,7 @@ bool NativePaddlePredictor::Init(
"The sub_scope should not be nullptr."));
} else {
paddle::framework::InitDevices();
paddle::framework::InitDefaultKernelSignatureMap();
scope_.reset(new paddle::framework::Scope());
}
......
......@@ -517,10 +517,8 @@ class WhileGradOpShapeInference : public framework::InferShapeBase {
ctx->HasInputs(kOutputs);
ctx->HasInputs(framework::GradVarName(kOutputs));
auto pg_ig_names = ctx->Outputs(kXGRAD);
std::vector<framework::InferShapeVarPtr> in_var_ptrs =
ctx->GetInputVarPtrs(kX);
std::vector<framework::InferShapeVarPtr> out_var_ptrs =
ctx->GetOutputVarPtrs(kXGRAD);
auto in_var_ptrs = ctx->GetInputVarPtrs(kX);
auto out_var_ptrs = ctx->GetOutputVarPtrs(kXGRAD);
PADDLE_ENFORCE_EQ(in_var_ptrs.size(), out_var_ptrs.size(),
platform::errors::InvalidArgument(
"The size of Inputs(X) must be the same as "
......
......@@ -63,10 +63,8 @@ class CollectFpnProposalsOp : public framework::OperatorWithKernel {
context->ShareLoD("MultiLevelRois", "FpnRois");
}
if (context->IsRuntime() && !context->HasInputs("MultiLevelRoIsNum")) {
std::vector<framework::InferShapeVarPtr> roi_inputs =
context->GetInputVarPtrs("MultiLevelRois");
std::vector<framework::InferShapeVarPtr> score_inputs =
context->GetInputVarPtrs("MultiLevelScores");
auto roi_inputs = context->GetInputVarPtrs("MultiLevelRois");
auto score_inputs = context->GetInputVarPtrs("MultiLevelScores");
for (size_t i = 0; i < roi_inputs.size(); ++i) {
framework::Variable *roi_var =
BOOST_GET(framework::Variable *, roi_inputs[i]);
......
......@@ -60,6 +60,7 @@ limitations under the License. */
#include "paddle/fluid/pybind/uva_utils.h"
#include "paddle/phi/core/compat/arg_map_context.h"
#include "paddle/phi/core/compat/type_defs.h"
#include "paddle/phi/core/type_defs.h"
namespace paddle {
namespace pybind {
......@@ -2027,26 +2028,35 @@ void BindImperative(py::module *m_ptr) {
*(imperative::AmpOperators::Instance().GetMutableAllowOps()),
*(imperative::AmpOperators::Instance().GetMutableBlockOps()));
})
.def("_get_kernel_signature",
[](imperative::Tracer &self, const std::string &type,
const PyNameVarBaseMap &ins, const PyNameVarBaseMap &outs,
framework::AttributeMap attrs) {
// TODO(xiongkun): move this function outside of tracer.
auto ins_map = ConvertToNameTensorMap(ins);
auto outs_map = ConvertToNameTensorMap(outs);
{
auto to_vector = [](paddle::SmallVector<std::string> &vec) {
return std::vector<std::string>(vec.begin(), vec.end());
};
auto ret = self.GetExpectedKernelSignature(type, ins_map,
outs_map, attrs);
auto kernelsig_ins = to_vector(std::get<0>(ret.args));
auto kernelsig_attrs = to_vector(std::get<1>(ret.args));
auto kernelsig_outs = to_vector(std::get<2>(ret.args));
return std::make_tuple(kernelsig_ins, kernelsig_attrs,
kernelsig_outs);
}
})
.def(
"_get_kernel_signature",
[](imperative::Tracer &self, const std::string &type,
const PyNameVarBaseMap &ins, const PyNameVarBaseMap &outs,
framework::AttributeMap attrs) {
// TODO(xiongkun): move this function outside of tracer.
auto ins_map = ConvertToNameTensorMap(ins);
auto outs_map = ConvertToNameTensorMap(outs);
{
auto input_to_vector =
[](paddle::SmallVector<const char *> &vec) {
return std::vector<std::string>(vec.begin(), vec.end());
};
auto output_to_vector =
[](paddle::SmallVector<const char *> &vec) {
return std::vector<std::string>(vec.begin(), vec.end());
};
auto attr_to_vector = [](paddle::SmallVector<const char *> &vec) {
return std::vector<std::string>(vec.begin(), vec.end());
};
auto ret = self.GetExpectedKernelSignature(type, ins_map,
outs_map, attrs);
auto kernelsig_ins = input_to_vector(std::get<0>(ret.args));
auto kernelsig_attrs = attr_to_vector(std::get<1>(ret.args));
auto kernelsig_outs = output_to_vector(std::get<2>(ret.args));
return std::make_tuple(kernelsig_ins, kernelsig_attrs,
kernelsig_outs);
}
})
.def("trace",
[](imperative::Tracer &self, const std::string &type,
const PyNameVarBaseMap &ins, const PyNameVarBaseMap &outs,
......
......@@ -2907,6 +2907,8 @@ All parameter, weight, gradient are variables in Paddle.
framework::LoadOpMetaInfoAndRegisterOp(dso_name));
});
m.def("init_devices", []() { framework::InitDevices(); });
m.def("init_default_kernel_signatures",
[]() { framework::InitDefaultKernelSignatureMap(); });
m.def("is_compiled_with_cuda", IsCompiledWithCUDA);
m.def("is_compiled_with_ascend", IsCompiledWithAscend);
m.def("is_compiled_with_rocm", IsCompiledWithROCM);
......
......@@ -15,6 +15,7 @@
#include "paddle/infrt/dialect/phi/pass/kernel_op_desc.h"
#include <glog/logging.h>
#include "paddle/infrt/dialect/phi/data_type.h"
#include "paddle/phi/core/type_defs.h"
#include "paddle/phi/kernels/declarations.h"
namespace infrt {
......@@ -92,10 +93,10 @@ std::vector<PhiKernelDesc> GetCandidateKernels(
phi_kernel_desc.input_types.clear();
phi_kernel_desc.output_types.clear();
phi::KernelArgsDef args_def = kernel_key_map.at(kernel_key).args_def();
const paddle::SmallVector<phi::TensorArgDef>& input_arg =
args_def.input_defs();
const paddle::SmallVector<phi::TensorArgDef>& output_arg =
args_def.output_defs();
const paddle::SmallVector<phi::TensorArgDef, phi::kInputSmallVectorSize>&
input_arg = args_def.input_defs();
const paddle::SmallVector<phi::TensorArgDef, phi::kOutputSmallVectorSize>&
output_arg = args_def.output_defs();
for (auto tensor_arg : input_arg) {
phi_kernel_desc.input_types.emplace_back(ConvertPlaceFromPhi(tensor_arg));
}
......
......@@ -91,6 +91,7 @@ using ValueVariantType =
std::vector<::phi::DenseTensor*>,
paddle::experimental::ScalarBase<::phi::DenseTensor>,
paddle::experimental::IntArrayBase<::phi::DenseTensor>,
std::vector<const ::phi::MetaTensor*>,
std::vector<::phi::MetaTensor*>,
::phi::MetaConfig,
paddle::experimental::Backend,
......
......@@ -271,10 +271,10 @@ std::vector<Tensor> split_impl(const Tensor& x,
// Calculate the number of out tensors
size_t out_number;
if (num_or_sections.GetData().size() == 1) {
if (num_or_sections.size() == 1) {
out_number = num_or_sections.GetData()[0];
} else {
out_number = num_or_sections.GetData().size();
out_number = num_or_sections.size();
}
std::vector<Tensor> out;
......@@ -449,54 +449,6 @@ std::tuple<Tensor, Tensor, Tensor> momentum_impl(
return api_output;
}
std::vector<Tensor> unbind_impl(const Tensor& input, int axis) {
auto kernel_key_set = ParseKernelKeyByInputArgs(input);
auto kernel_key = kernel_key_set.GetHighestPriorityKernelKey();
Backend kernel_backend = kernel_key.backend();
DataLayout kernel_layout = kernel_key.layout();
DataType kernel_data_type = kernel_key.dtype();
auto kernel = phi::KernelFactory::Instance().SelectKernelOrThrowError(
"unbind", {kernel_backend, kernel_layout, kernel_data_type});
VLOG(6) << "unbind API kernel key: [" << kernel_backend << ", "
<< kernel_layout << ", " << kernel_data_type << "]";
VLOG(6) << "unbind API kernel: " << kernel;
auto* dev_ctx = GetDeviceContextByBackend(kernel_backend);
auto dense_input = PrepareData(input, kernel.InputAt(0), {});
// Calculate the number of out tensors
auto input_shape = input.dims();
if (axis < 0) {
axis = input_shape.size() + axis;
}
auto out_num = input_shape[axis];
std::vector<Tensor> out;
auto dense_outs = SetKernelOutput(out_num, kernel_backend, &out);
std::vector<phi::MetaTensor> meta_outs;
meta_outs.reserve(out_num);
std::vector<phi::MetaTensor*> meta_out_ptrs;
meta_out_ptrs.reserve(out_num);
for (int64_t i = 0; i < out_num; ++i) {
meta_outs.push_back(dense_outs[i]);
meta_out_ptrs.push_back(&meta_outs.back());
}
phi::UnbindInferMeta(MakeMetaTensor(*dense_input), axis, meta_out_ptrs);
using kernel_signature = void (*)(const phi::DeviceContext&,
const phi::DenseTensor&,
int,
std::vector<phi::DenseTensor*>&);
auto* kernel_fn = kernel.GetVariadicKernelFn<kernel_signature>();
(*kernel_fn)(*dev_ctx, *dense_input, axis, dense_outs);
return out;
}
////////////////// Backward(grad) api impls //////////////////////
// TODO(chenweihang): the original sum grad op can support higher-level
......@@ -674,71 +626,6 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor, Tensor> batch_norm_impl(
return api_output;
}
// Manually-written backward API impl for `concat`: computes the gradient
// of each input in `x` given the gradient of the concatenated output.
// Returns one gradient tensor per forward input (x_grad.size() == x.size()).
std::vector<Tensor> concat_grad_impl(const std::vector<Tensor>& x,
const Tensor& out_grad,
const Scalar& axis) {
// Kernel key (backend / layout / dtype) is derived from the incoming
// output gradient only.
auto kernel_key_set = ParseKernelKeyByInputArgs(out_grad);
auto kernel_key = kernel_key_set.GetHighestPriorityKernelKey();
Backend kernel_backend = kernel_key.backend();
DataLayout kernel_layout = kernel_key.layout();
DataType kernel_data_type = kernel_key.dtype();
auto kernel = phi::KernelFactory::Instance().SelectKernelOrThrowError(
"concat_grad", {kernel_backend, kernel_layout, kernel_data_type});
VLOG(6) << "concat_grad API kernel key: [" << kernel_backend << ", "
<< kernel_layout << ", " << kernel_data_type << "]";
VLOG(6) << "concat_grad API kernel: " << kernel;
auto* dev_ctx = GetDeviceContextByBackend(kernel_backend);
// std::unique_ptr<std::vector<phi::DenseTensor>>
auto dense_x = PrepareData(x, kernel.InputAt(0), {});
auto dense_out_grad = PrepareData(out_grad, kernel.InputAt(1), {});
// Calculate the number of out tensors
size_t out_number = x.size();
std::vector<Tensor> x_grad;
auto dense_x_grad = SetKernelOutput(out_number, kernel_backend, &x_grad);
// MetaTensor views over the forward inputs. reserve() before taking
// &meta_x.back() so the pointers are not invalidated by reallocation.
std::vector<phi::MetaTensor> meta_x;
meta_x.reserve(x.size());
std::vector<phi::MetaTensor*> meta_x_ptrs;
meta_x_ptrs.reserve(x.size());
for (const auto& t : *dense_x) {
meta_x.push_back(t);
meta_x_ptrs.push_back(&meta_x.back());
}
// MetaTensor views over the gradient outputs (same reserve-first pattern).
std::vector<phi::MetaTensor> meta_x_grad;
meta_x_grad.reserve(x.size());
std::vector<phi::MetaTensor*> meta_x_grad_ptrs;
meta_x_grad_ptrs.reserve(x.size());
for (size_t i = 0; i < out_number; ++i) {
meta_x_grad.push_back(*dense_x_grad[i]);
meta_x_grad_ptrs.push_back(&meta_x_grad.back());
}
// Each x_grad[i] takes the meta (dims/dtype/...) of the matching x[i].
phi::UnchangedMultiInferMeta(meta_x_ptrs, meta_x_grad_ptrs);
// The kernel takes the forward inputs as a vector of const pointers.
std::vector<const phi::DenseTensor*> dense_x_ptr;
dense_x_ptr.reserve(x.size());
for (const auto& t : *dense_x) {
dense_x_ptr.push_back(&t);
}
using kernel_signature = void (*)(const platform::DeviceContext&,
const std::vector<const phi::DenseTensor*>&,
const phi::DenseTensor&,
const phi::Scalar&,
std::vector<phi::DenseTensor*>);
auto* kernel_fn = kernel.GetVariadicKernelFn<kernel_signature>();
(*kernel_fn)(
*dev_ctx, dense_x_ptr, *dense_out_grad, phi::Scalar(axis), dense_x_grad);
return x_grad;
}
Tensor imag_grad_impl(const Tensor& out_grad) {
phi::KernelKey kernel_key{ParseBackend(out_grad),
out_grad.layout(),
......@@ -795,328 +682,5 @@ Tensor real_grad_impl(const Tensor& out_grad) {
return out;
}
std::vector<Tensor> stack_grad_impl(const std::vector<Tensor>& x,
const Tensor& out_grad,
int axis) {
auto kernel_key_set = ParseKernelKeyByInputArgs(out_grad);
auto kernel_key = kernel_key_set.GetHighestPriorityKernelKey();
Backend kernel_backend = kernel_key.backend();
DataLayout kernel_layout = kernel_key.layout();
DataType kernel_data_type = kernel_key.dtype();
auto kernel = phi::KernelFactory::Instance().SelectKernelOrThrowError(
"stack_grad", {kernel_backend, kernel_layout, kernel_data_type});
VLOG(6) << "stack_grad API kernel key: [" << kernel_backend << ", "
<< kernel_layout << ", " << kernel_data_type << "]";
VLOG(6) << "stack_grad API kernel: " << kernel;
auto* dev_ctx = GetDeviceContextByBackend(kernel_backend);
auto dense_out_grad = PrepareData(out_grad, kernel.InputAt(0), {});
size_t out_number = x.size();
std::vector<Tensor> x_grad;
auto dense_x_grad = SetKernelOutput(out_number, kernel_backend, &x_grad);
std::vector<phi::MetaTensor> meta_x_grad;
meta_x_grad.reserve(out_number);
std::vector<phi::MetaTensor*> meta_x_grad_ptrs;
meta_x_grad_ptrs.reserve(out_number);
for (size_t i = 0; i < out_number; ++i) {
meta_x_grad.push_back(dense_x_grad[i]);
meta_x_grad_ptrs.push_back(&meta_x_grad.back());
}
phi::StackGradInferMeta(
MakeMetaTensor(*dense_out_grad), axis, meta_x_grad_ptrs);
using kernel_signature = void (*)(const platform::DeviceContext&,
const phi::DenseTensor&,
int axis,
std::vector<phi::DenseTensor*>);
auto* kernel_fn = kernel.GetVariadicKernelFn<kernel_signature>();
(*kernel_fn)(*dev_ctx, *dense_out_grad, axis, dense_x_grad);
return x_grad;
}
// Manually-written API impl for `meshgrid`: produces one output tensor per
// input tensor (coordinate grids). Dispatches the phi "meshgrid" kernel.
std::vector<Tensor> meshgrid_impl(const std::vector<Tensor>& inputs) {
// Start with undefined key parts, then fill each one from the inputs.
// (The guard below is always taken here; it mirrors the generated-code
// pattern where some parts may be pre-set.)
Backend kernel_backend = Backend::UNDEFINED;
DataLayout kernel_layout = DataLayout::UNDEFINED;
DataType kernel_data_type = DataType::UNDEFINED;
if (kernel_backend == Backend::UNDEFINED ||
kernel_layout == DataLayout::UNDEFINED ||
kernel_data_type == DataType::UNDEFINED) {
auto kernel_key_set = ParseKernelKeyByInputArgs(inputs);
auto kernel_key = kernel_key_set.GetHighestPriorityKernelKey();
if (kernel_backend == Backend::UNDEFINED) {
kernel_backend = kernel_key.backend();
}
if (kernel_layout == DataLayout::UNDEFINED) {
kernel_layout = kernel_key.layout();
}
if (kernel_data_type == DataType::UNDEFINED) {
kernel_data_type = kernel_key.dtype();
}
}
const auto& kernel = phi::KernelFactory::Instance().SelectKernelOrThrowError(
"meshgrid", {kernel_backend, kernel_layout, kernel_data_type});
VLOG(6) << "meshgrid API kernel key: [" << kernel_backend << ", "
<< kernel_layout << ", " << kernel_data_type << "]";
VLOG(6) << "meshgrid API kernel: " << kernel;
auto* dev_ctx = GetDeviceContextByBackend(kernel_backend);
// Prepare inputs and collect raw const pointers for the kernel signature.
auto input_inputs_vec = PrepareData(inputs, kernel.InputAt(0), {});
std::vector<const phi::DenseTensor*> input_inputs(input_inputs_vec->size());
for (size_t i = 0; i < input_inputs.size(); ++i) {
input_inputs[i] = &input_inputs_vec->at(i);
}
// MetaTensor views over the inputs for infer-meta.
auto x_meta_vec = MakeMetaTensor(input_inputs);
std::vector<phi::MetaTensor*> inputs_metas(x_meta_vec.size());
for (size_t i = 0; i < x_meta_vec.size(); ++i) {
inputs_metas[i] = &x_meta_vec[i];
}
// Calculate the number of out tensors
size_t out_number = inputs.size();
std::vector<Tensor> out;
auto dense_outs = SetKernelOutput(out_number, kernel_backend, &out);
// Output MetaTensors: reserve() before taking &back() so the pointers
// stay valid while the vector is filled.
std::vector<phi::MetaTensor> meta_outs;
meta_outs.reserve(out_number);
std::vector<phi::MetaTensor*> meta_out_ptrs;
meta_out_ptrs.reserve(out_number);
for (size_t i = 0; i < out_number; ++i) {
meta_outs.push_back(dense_outs[i]);
meta_out_ptrs.push_back(&meta_outs.back());
}
phi::MeshgridInferMeta(inputs_metas, meta_out_ptrs);
using kernel_signature = void (*)(const platform::DeviceContext&,
const std::vector<const phi::DenseTensor*>&,
std::vector<phi::DenseTensor*>&);
auto* kernel_fn = kernel.GetVariadicKernelFn<kernel_signature>();
(*kernel_fn)(*dev_ctx, input_inputs, dense_outs);
return out;
}
// Manually-written backward API impl for `meshgrid`: given the forward
// inputs and the gradients of all outputs, produces one gradient tensor
// per forward input.
std::vector<Tensor> meshgrid_grad_impl(
const std::vector<Tensor>& inputs,
const std::vector<Tensor>& outputs_grad) {
// Start with undefined key parts, then fill each one from the inputs
// and the output gradients (guard mirrors the generated-code pattern).
Backend kernel_backend = Backend::UNDEFINED;
DataLayout kernel_layout = DataLayout::UNDEFINED;
DataType kernel_data_type = DataType::UNDEFINED;
if (kernel_backend == Backend::UNDEFINED ||
kernel_layout == DataLayout::UNDEFINED ||
kernel_data_type == DataType::UNDEFINED) {
auto kernel_key_set = ParseKernelKeyByInputArgs(inputs, outputs_grad);
auto kernel_key = kernel_key_set.GetHighestPriorityKernelKey();
if (kernel_backend == Backend::UNDEFINED) {
kernel_backend = kernel_key.backend();
}
if (kernel_layout == DataLayout::UNDEFINED) {
kernel_layout = kernel_key.layout();
}
if (kernel_data_type == DataType::UNDEFINED) {
kernel_data_type = kernel_key.dtype();
}
}
const auto& kernel = phi::KernelFactory::Instance().SelectKernelOrThrowError(
"meshgrid_grad", {kernel_backend, kernel_layout, kernel_data_type});
VLOG(6) << "meshgrid_grad API kernel key: [" << kernel_backend << ", "
<< kernel_layout << ", " << kernel_data_type << "]";
VLOG(6) << "meshgrid_grad API kernel: " << kernel;
auto* dev_ctx = GetDeviceContextByBackend(kernel_backend);
// Prepare both input lists and collect raw const pointers for the kernel.
auto input_inputs_vec = PrepareData(inputs, kernel.InputAt(0), {});
std::vector<const phi::DenseTensor*> input_inputs(input_inputs_vec->size());
for (size_t i = 0; i < input_inputs.size(); ++i) {
input_inputs[i] = &input_inputs_vec->at(i);
}
auto input_outputs_grad_vec =
PrepareData(outputs_grad, kernel.InputAt(1), {});
std::vector<const phi::DenseTensor*> input_outputs_grad(
input_outputs_grad_vec->size());
for (size_t i = 0; i < input_outputs_grad.size(); ++i) {
input_outputs_grad[i] = &input_outputs_grad_vec->at(i);
}
// One gradient output per forward input.
size_t out_number = inputs.size();
std::vector<Tensor> api_output;
auto kernel_out = SetKernelOutput(out_number, kernel_backend, &api_output);
// MetaTensor views for infer-meta over inputs and output gradients.
auto inputs_meta_vec = MakeMetaTensor(input_inputs);
std::vector<phi::MetaTensor*> inputs_metas(inputs_meta_vec.size());
for (size_t i = 0; i < inputs_meta_vec.size(); ++i) {
inputs_metas[i] = &inputs_meta_vec[i];
}
auto outputs_grad_meta_vec = MakeMetaTensor(input_outputs_grad);
std::vector<phi::MetaTensor*> outputs_grad_metas(
outputs_grad_meta_vec.size());
for (size_t i = 0; i < outputs_grad_meta_vec.size(); ++i) {
outputs_grad_metas[i] = &outputs_grad_meta_vec[i];
}
// Output MetaTensors: reserve() before taking &back() so the pointers
// stay valid while the vector is filled.
std::vector<phi::MetaTensor> meta_outs;
meta_outs.reserve(out_number);
std::vector<phi::MetaTensor*> meta_out_ptrs;
meta_out_ptrs.reserve(out_number);
for (size_t i = 0; i < out_number; ++i) {
meta_outs.push_back(kernel_out[i]);
meta_out_ptrs.push_back(&meta_outs.back());
}
phi::MeshgridGradInferMeta(inputs_metas, outputs_grad_metas, meta_out_ptrs);
using kernel_signature = void (*)(const platform::DeviceContext&,
const std::vector<const phi::DenseTensor*>&,
const std::vector<const phi::DenseTensor*>&,
std::vector<phi::DenseTensor*>&);
auto* kernel_fn = kernel.GetVariadicKernelFn<kernel_signature>();
(*kernel_fn)(*dev_ctx, input_inputs, input_outputs_grad, kernel_out);
return api_output;
}
// Manually-written backward API impl for `multi_dot`: given the forward
// operand list `x` and the gradient of the product, produces one gradient
// tensor per operand.
std::vector<Tensor> multi_dot_grad_impl(const std::vector<Tensor>& x,
const Tensor& out_grad) {
// Start with undefined key parts, then fill each one from x / out_grad
// (guard mirrors the generated-code pattern).
Backend kernel_backend = Backend::UNDEFINED;
DataLayout kernel_layout = DataLayout::UNDEFINED;
DataType kernel_data_type = DataType::UNDEFINED;
if (kernel_backend == Backend::UNDEFINED ||
kernel_layout == DataLayout::UNDEFINED ||
kernel_data_type == DataType::UNDEFINED) {
auto kernel_key_set = ParseKernelKeyByInputArgs(x, out_grad);
auto kernel_key = kernel_key_set.GetHighestPriorityKernelKey();
if (kernel_backend == Backend::UNDEFINED) {
kernel_backend = kernel_key.backend();
}
if (kernel_layout == DataLayout::UNDEFINED) {
kernel_layout = kernel_key.layout();
}
if (kernel_data_type == DataType::UNDEFINED) {
kernel_data_type = kernel_key.dtype();
}
}
VLOG(6) << "multi_dot_grad API kernel key: [" << kernel_backend << ", "
<< kernel_layout << ", " << kernel_data_type << "]";
const auto& kernel = phi::KernelFactory::Instance().SelectKernelOrThrowError(
"multi_dot_grad", {kernel_backend, kernel_layout, kernel_data_type});
VLOG(6) << "multi_dot_grad API kernel: " << kernel;
auto* dev_ctx = GetDeviceContextByBackend(kernel_backend);
// Prepare operands and collect raw const pointers for the kernel signature.
auto input_x_vec = PrepareData(x, kernel.InputAt(0), {});
std::vector<const phi::DenseTensor*> input_x(input_x_vec->size());
for (size_t i = 0; i < input_x.size(); ++i) {
input_x[i] = &input_x_vec->at(i);
}
auto input_out_grad = PrepareData(out_grad, kernel.InputAt(1), {});
// One gradient output per forward operand.
size_t out_number = input_x.size();
std::vector<Tensor> api_output;
auto kernel_out = SetKernelOutput(out_number, kernel_backend, &api_output);
// MetaTensor views over the operands for infer-meta.
auto x_meta_vec = MakeMetaTensor(input_x);
std::vector<phi::MetaTensor*> x_metas(x_meta_vec.size());
for (size_t i = 0; i < x_meta_vec.size(); ++i) {
x_metas[i] = &x_meta_vec[i];
}
// Output MetaTensors: reserve() before taking &back() so the pointers
// stay valid while the vector is filled.
std::vector<phi::MetaTensor> meta_outs;
meta_outs.reserve(out_number);
std::vector<phi::MetaTensor*> meta_out_ptrs;
meta_out_ptrs.reserve(out_number);
for (size_t i = 0; i < out_number; ++i) {
meta_outs.push_back(kernel_out[i]);
meta_out_ptrs.push_back(&meta_outs.back());
}
phi::MultiDotGradInferMeta(
x_metas, MakeMetaTensor(*input_out_grad), meta_out_ptrs);
using kernel_signature = void (*)(const platform::DeviceContext&,
const std::vector<const phi::DenseTensor*>&,
const phi::DenseTensor&,
std::vector<phi::DenseTensor*>&);
auto* kernel_fn = kernel.GetVariadicKernelFn<kernel_signature>();
(*kernel_fn)(*dev_ctx, input_x, *input_out_grad, kernel_out);
return api_output;
}
// Manually-written backward API impl for `multiplex`: given the candidate
// inputs, the selection index tensor `ids`, and the output gradient,
// produces one gradient tensor per candidate input.
std::vector<Tensor> multiplex_grad_impl(const std::vector<Tensor>& inputs,
const Tensor& ids,
const Tensor& out_grad) {
// Start with undefined key parts, then fill each one from out_grad
// (guard mirrors the generated-code pattern).
Backend kernel_backend = Backend::UNDEFINED;
DataLayout kernel_layout = DataLayout::UNDEFINED;
DataType kernel_data_type = DataType::UNDEFINED;
if (kernel_backend == Backend::UNDEFINED ||
kernel_layout == DataLayout::UNDEFINED ||
kernel_data_type == DataType::UNDEFINED) {
auto kernel_key_set = ParseKernelKeyByInputArgs(out_grad);
auto kernel_key = kernel_key_set.GetHighestPriorityKernelKey();
if (kernel_backend == Backend::UNDEFINED) {
kernel_backend = kernel_key.backend();
}
if (kernel_layout == DataLayout::UNDEFINED) {
kernel_layout = kernel_key.layout();
}
if (kernel_data_type == DataType::UNDEFINED) {
kernel_data_type = kernel_key.dtype();
}
}
VLOG(6) << "multiplex_grad API kernel key: [" << kernel_backend << ", "
<< kernel_layout << ", " << kernel_data_type << "]";
const auto& kernel = phi::KernelFactory::Instance().SelectKernelOrThrowError(
"multiplex_grad", {kernel_backend, kernel_layout, kernel_data_type});
VLOG(6) << "multiplex_grad API kernel: " << kernel;
auto* dev_ctx = GetDeviceContextByBackend(kernel_backend);
auto input_ids = PrepareData(ids, kernel.InputAt(0), {});
auto input_out_grad = PrepareData(out_grad, kernel.InputAt(1), {});
// One gradient output per candidate input.
auto out_number = inputs.size();
std::vector<Tensor> api_output;
auto kernel_out = SetKernelOutput(out_number, kernel_backend, &api_output);
// Output MetaTensors: reserve() before taking &back() so the pointers
// stay valid while the vector is filled.
std::vector<phi::MetaTensor> meta_outs;
meta_outs.reserve(out_number);
std::vector<phi::MetaTensor*> meta_out_ptrs;
meta_out_ptrs.reserve(out_number);
for (size_t i = 0; i < out_number; ++i) {
meta_outs.push_back(kernel_out[i]);
meta_out_ptrs.push_back(&meta_outs.back());
}
phi::MultiplexGradInferMeta(MakeMetaTensor(*input_ids),
MakeMetaTensor(*input_out_grad),
meta_out_ptrs);
using kernel_signature = void (*)(const platform::DeviceContext&,
const phi::DenseTensor&,
const phi::DenseTensor&,
std::vector<phi::DenseTensor*>&);
auto* kernel_fn = kernel.GetVariadicKernelFn<kernel_signature>();
(*kernel_fn)(*dev_ctx, *input_ids, *input_out_grad, kernel_out);
return api_output;
}
} // namespace experimental
} // namespace paddle
......@@ -30,6 +30,20 @@ namespace experimental {
////////////////// Forward api impls //////////////////////
std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor, Tensor> batch_norm_impl(
const Tensor& x,
const Tensor& scale,
const Tensor& bias,
const Tensor& mean,
const Tensor& variance,
float momentum,
float epsilon,
const std::string& data_layout,
bool is_test,
bool use_global_stats,
bool trainable_statistics,
bool fuse_with_relu);
Tensor conv2d_impl(const Tensor& input,
const Tensor& filter,
const std::vector<int>& strides,
......@@ -62,8 +76,6 @@ std::vector<Tensor> split_impl(const Tensor& x,
const IntArray& num_or_sections,
const Scalar& axis);
std::vector<Tensor> meshgrid_impl(const std::vector<Tensor>& inputs);
std::tuple<Tensor, Tensor, Tensor> momentum_impl(
const Tensor& param,
const Tensor& grad,
......@@ -77,49 +89,14 @@ std::tuple<Tensor, Tensor, Tensor> momentum_impl(
bool multi_precision,
float rescale_grad);
std::vector<Tensor> unbind_impl(const Tensor& input, int axis);
////////////////// Backward(grad) api impls //////////////////////
std::vector<Tensor> add_n_grad_impl(const std::vector<Tensor>& x,
const Tensor& out_grad);
std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor, Tensor> batch_norm_impl(
const Tensor& x,
const Tensor& scale,
const Tensor& bias,
const Tensor& mean,
const Tensor& variance,
float momentum,
float epsilon,
const std::string& data_layout,
bool is_test,
bool use_global_stats,
bool trainable_statistics,
bool fuse_with_relu);
/************************ backward api impl ***************************/
std::vector<Tensor> concat_grad_impl(const std::vector<Tensor>& x,
const Tensor& out_grad,
const Scalar& axis);
Tensor imag_grad_impl(const Tensor& x);
Tensor real_grad_impl(const Tensor& x);
std::vector<Tensor> stack_grad_impl(const std::vector<Tensor>& x,
const Tensor& out_grad,
int axis);
std::vector<Tensor> meshgrid_grad_impl(const std::vector<Tensor>& inputs,
const std::vector<Tensor>& outputs_grad);
std::vector<Tensor> multi_dot_grad_impl(const std::vector<Tensor>& x,
const Tensor& out_grad);
std::vector<Tensor> multiplex_grad_impl(const std::vector<Tensor>& inputs,
const Tensor& ids,
const Tensor& out_grad);
} // namespace experimental
} // namespace paddle
......@@ -76,6 +76,16 @@ std::vector<phi::MetaTensor> MakeMetaTensor(
return meta_tensors;
}
// Wraps each DenseTensor (given by non-const pointer) in a MetaTensor
// view. No tensor data is copied; only meta views are created.
std::vector<phi::MetaTensor> MakeMetaTensor(
    const std::vector<phi::DenseTensor*>& tensors) {
  std::vector<phi::MetaTensor> result;
  result.reserve(tensors.size());
  for (size_t i = 0; i < tensors.size(); ++i) {
    result.emplace_back(*tensors[i]);
  }
  return result;
}
phi::MetaTensor MakeMetaTensor(const phi::SelectedRows& tensor) {
return phi::MetaTensor(tensor);
}
......
......@@ -53,6 +53,9 @@ phi::MetaTensor MakeMetaTensor(const phi::DenseTensor& tensor);
std::vector<phi::MetaTensor> MakeMetaTensor(
const std::vector<const phi::DenseTensor*>& tensors);
std::vector<phi::MetaTensor> MakeMetaTensor(
const std::vector<phi::DenseTensor*>& tensors);
phi::MetaTensor MakeMetaTensor(const phi::SelectedRows& tensor);
phi::MetaTensor MakeMetaTensor(const phi::StringTensor& tensor);
......
......@@ -96,6 +96,8 @@ class IntArrayBase {
template <typename OtherT>
IntArrayBase(const IntArrayBase<OtherT>& other) : array_(other.GetData()) {}
size_t size() const { return array_.size(); }
const std::vector<int64_t>& GetData() const { return array_; }
private:
......
......@@ -19,45 +19,33 @@ limitations under the License. */
#include <tuple>
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/type_defs.h"
#include "paddle/utils/any.h"
#include "paddle/utils/flat_hash_map.h"
#include "paddle/utils/small_vector.h"
namespace phi {
// Suffix appended to a variable name to form its gradient variable name.
constexpr char kGradVarSuffix[] = "@GRAD";
// Derive the suffix length from the literal itself (array size minus the
// trailing '\0') instead of hard-coding 5U, so the two can never drift apart.
constexpr size_t kGradVarSuffixSize = sizeof(kGradVarSuffix) - 1;

// Returns "<var_name>@GRAD". Reserves the final size up front so the
// two appends never reallocate.
inline std::string GradVarName(const std::string& var_name) {
  std::string result;
  result.reserve(var_name.size() + kGradVarSuffixSize);
  result += var_name;
  result += kGradVarSuffix;
  return result;
}
// tuple(input_names, attr_names, output_names)
using KernelArgsTuple = std::tuple<paddle::SmallVector<std::string>,
paddle::SmallVector<std::string>,
paddle::SmallVector<std::string>>;
using KernelArgsTuple = std::tuple<paddle::SmallVector<const char*>,
paddle::SmallVector<const char*>,
paddle::SmallVector<const char*>>;
struct KernelSignature {
std::string name;
const char* name;
KernelArgsTuple args;
KernelSignature() = default;
KernelSignature(std::string&& kernel_name,
paddle::SmallVector<std::string>&& inputs,
paddle::SmallVector<std::string>&& attrs,
paddle::SmallVector<std::string>&& outputs)
: name(std::move(kernel_name)),
args(std::make_tuple(inputs, attrs, outputs)) {}
KernelSignature(const std::string& kernel_name,
const paddle::SmallVector<std::string>& inputs,
const paddle::SmallVector<std::string>& attrs,
const paddle::SmallVector<std::string>& outputs)
KernelSignature(const char* kernel_name,
paddle::SmallVector<const char*>&& inputs,
paddle::SmallVector<const char*>&& attrs,
paddle::SmallVector<const char*>&& outputs)
: name(kernel_name), args(std::make_tuple(inputs, attrs, outputs)) {}
KernelSignature(const char* kernel_name,
const paddle::SmallVector<const char*>& inputs,
const paddle::SmallVector<const char*>& attrs,
const paddle::SmallVector<const char*>& outputs)
: name(kernel_name), args(std::make_tuple(inputs, attrs, outputs)) {}
// TODO(chenweihang): add assign constructor to solve windows compile
......
......@@ -102,7 +102,7 @@ phi::Place TransToPhiPlace(const Backend& backend, bool set_device_id) {
}
}
std::string TransToPhiKernelName(const std::string& fluid_op_name) {
const std::string& TransToPhiKernelName(const std::string& fluid_op_name) {
return OpUtilsMap::Instance().GetBaseKernelName(fluid_op_name);
}
......
......@@ -22,7 +22,7 @@ limitations under the License. */
namespace phi {
std::string TransToPhiKernelName(const std::string& fluid_op_name);
const std::string& TransToPhiKernelName(const std::string& fluid_op_name);
const std::string& TransToFluidOpName(const std::string& phi_kernel_name);
Backend TransToPhiBackend(const phi::Place& place);
......
......@@ -26,6 +26,8 @@ limitations under the License. */
namespace phi {
const static std::string deprecated_kernel_name = "deprecated"; // NOLINT
const std::unordered_set<std::string> standard_kernel_suffixs({
"sr", // SelectedRows kernel
"raw" // fallback kernel of origfinal fluid op
......@@ -134,9 +136,9 @@ class OpUtilsMap {
arg_mapping_fn_map_.insert({std::move(op_type), std::move(fn)});
}
std::string GetBaseKernelName(const std::string& op_type) const {
const std::string& GetBaseKernelName(const std::string& op_type) const {
if (deprecated_op_names.find(op_type) != deprecated_op_names.end()) {
return "deprecated";
return deprecated_kernel_name;
}
auto it = base_kernel_name_map_.find(op_type);
if (it == base_kernel_name_map_.end()) {
......@@ -150,7 +152,7 @@ class OpUtilsMap {
auto it = arg_mapping_fn_map_.find(op_type);
if (it == arg_mapping_fn_map_.end()) {
auto func =
[op_type](const ArgumentMappingContext& ctx) -> KernelSignature {
[&op_type](const ArgumentMappingContext& ctx) -> KernelSignature {
return DefaultKernelSignatureMap::Instance().Get(op_type);
};
return func;
......
......@@ -20,14 +20,12 @@ void InferMetaContext::SetMetaConfig(MetaConfig config) {
config_ = std::move(config);
}
void InferMetaContext::EmplaceBackInput(
std::shared_ptr<phi::MetaTensor> input) {
void InferMetaContext::EmplaceBackInput(MetaTensor input) {
int index = inputs_.size();
inputs_.emplace_back(std::move(input));
input_range_.emplace_back(std::pair<int, int>(index, index + 1));
}
void InferMetaContext::EmplaceBackOutput(
std::shared_ptr<phi::MetaTensor> output) {
void InferMetaContext::EmplaceBackOutput(MetaTensor output) {
int index = outputs_.size();
outputs_.emplace_back(std::move(output));
output_range_.emplace_back(std::pair<int, int>(index, index + 1));
......@@ -37,7 +35,7 @@ void InferMetaContext::EmplaceBackAttr(paddle::any attr) {
}
void InferMetaContext::EmplaceBackInputs(
paddle::SmallVector<std::shared_ptr<phi::MetaTensor>> inputs) {
paddle::SmallVector<MetaTensor, phi::kInputSmallVectorSize> inputs) {
int index = inputs_.size();
input_range_.emplace_back(std::pair<int, int>(index, index + inputs.size()));
inputs_.insert(inputs_.end(),
......@@ -45,7 +43,7 @@ void InferMetaContext::EmplaceBackInputs(
std::make_move_iterator(inputs.end()));
}
void InferMetaContext::EmplaceBackOutputs(
paddle::SmallVector<std::shared_ptr<phi::MetaTensor>> outputs) {
paddle::SmallVector<MetaTensor, phi::kOutputSmallVectorSize> outputs) {
int index = outputs_.size();
output_range_.emplace_back(
std::pair<int, int>(index, index + outputs.size()));
......@@ -64,24 +62,25 @@ const std::pair<int, int>& InferMetaContext::OutputRangeAt(size_t idx) const {
const MetaConfig& InferMetaContext::GetMetaConfig() const { return config_; }
const MetaTensor& InferMetaContext::InputAt(size_t idx) const {
return *inputs_.at(idx);
return inputs_.at(idx);
}
paddle::optional<const phi::MetaTensor&> InferMetaContext::OptionalInputAt(
paddle::optional<const MetaTensor&> InferMetaContext::OptionalInputAt(
size_t idx) const {
const auto& input = inputs_.at(idx);
return input ? paddle::optional<const phi::MetaTensor&>{static_cast<
const phi::MetaTensor&>(*input)}
: paddle::optional<const phi::MetaTensor&>{paddle::none};
return input.initialized()
? paddle::optional<const MetaTensor&>{input}
: paddle::optional<const MetaTensor&>{paddle::none};
}
std::vector<MetaTensor*> InferMetaContext::InputsBetween(size_t start,
size_t end) const {
std::vector<MetaTensor*> result;
std::vector<const MetaTensor*> InferMetaContext::InputsBetween(
size_t start, size_t end) const {
std::vector<const MetaTensor*> result;
result.reserve(end - start);
for (size_t i = start; i < end; ++i) {
result.push_back(inputs_.at(i).get());
auto& in = inputs_.at(i);
result.emplace_back(in.initialized() ? &in : nullptr);
}
return result;
......@@ -91,12 +90,13 @@ paddle::optional<const std::vector<const MetaTensor*>>
InferMetaContext::OptionalInputsBetween(size_t start, size_t end) const {
const auto& first = inputs_.at(start);
if (first) {
if (first.initialized()) {
std::vector<const MetaTensor*> result;
result.reserve(end - start);
for (size_t i = start; i < end; ++i) {
result.push_back(inputs_.at(i).get());
auto& in = inputs_.at(i);
result.emplace_back(in.initialized() ? &in : nullptr);
}
return paddle::optional<const std::vector<const MetaTensor*>>(result);
......@@ -105,7 +105,8 @@ InferMetaContext::OptionalInputsBetween(size_t start, size_t end) const {
}
MetaTensor* InferMetaContext::MutableOutputAt(size_t idx) {
return outputs_.at(idx).get();
auto& out = outputs_.at(idx);
return out.initialized() ? &out : nullptr;
}
std::vector<MetaTensor*> InferMetaContext::MutableOutputBetween(size_t start,
......@@ -113,7 +114,8 @@ std::vector<MetaTensor*> InferMetaContext::MutableOutputBetween(size_t start,
std::vector<MetaTensor*> result;
result.reserve(end - start);
for (size_t i = start; i < end; ++i) {
result.emplace_back(outputs_.at(i).get());
auto& out = outputs_.at(i);
result.emplace_back(out.initialized() ? &out : nullptr);
}
return result;
}
......
......@@ -37,28 +37,28 @@ class InferMetaContext {
explicit InferMetaContext(MetaConfig config) : config_(config) {}
void SetMetaConfig(MetaConfig config);
void EmplaceBackInput(std::shared_ptr<phi::MetaTensor> input);
void EmplaceBackOutput(std::shared_ptr<phi::MetaTensor> output);
const MetaConfig& GetMetaConfig() const;
void EmplaceBackInput(MetaTensor input);
void EmplaceBackOutput(MetaTensor output);
void EmplaceBackAttr(paddle::any attr);
void EmplaceBackInputs(
paddle::SmallVector<std::shared_ptr<phi::MetaTensor>> inputs);
paddle::SmallVector<MetaTensor, phi::kInputSmallVectorSize> inputs);
void EmplaceBackOutputs(
paddle::SmallVector<std::shared_ptr<phi::MetaTensor>> outputs);
paddle::SmallVector<MetaTensor, phi::kOutputSmallVectorSize> outputs);
const std::pair<int, int>& InputRangeAt(size_t idx) const;
const std::pair<int, int>& OutputRangeAt(size_t idx) const;
virtual const MetaTensor& InputAt(size_t idx) const;
virtual paddle::optional<const MetaTensor&> OptionalInputAt(size_t idx) const;
const MetaConfig& GetMetaConfig() const;
const MetaTensor& InputAt(size_t idx) const;
paddle::optional<const phi::MetaTensor&> OptionalInputAt(size_t idx) const;
std::vector<MetaTensor*> InputsBetween(size_t start, size_t end) const;
paddle::optional<const std::vector<const phi::MetaTensor*>>
virtual std::vector<const MetaTensor*> InputsBetween(size_t start,
size_t end) const;
virtual paddle::optional<const std::vector<const MetaTensor*>>
OptionalInputsBetween(size_t start, size_t end) const;
MetaTensor* MutableOutputAt(size_t idx);
std::vector<MetaTensor*> MutableOutputBetween(size_t start, size_t end);
virtual MetaTensor* MutableOutputAt(size_t idx);
virtual std::vector<MetaTensor*> MutableOutputBetween(size_t start,
size_t end);
template <typename AttrType>
AttrType AttrAt(size_t idx) {
......@@ -73,19 +73,24 @@ class InferMetaContext {
}
}
private:
const std::pair<int, int>& InputRangeAt(size_t idx) const;
const std::pair<int, int>& OutputRangeAt(size_t idx) const;
virtual ~InferMetaContext() = default;
protected:
MetaConfig config_;
// NOTE(chenweihang): Because the MetaTensor is a base class, and MetaTensor
// objects are all created in each round, so we have to use smart pointer
// here, maybe we can implemented a new InferMetaContext and a series utils
// specifically for fluid to avoid using shared_ptr
paddle::SmallVector<std::shared_ptr<phi::MetaTensor>> inputs_;
paddle::SmallVector<std::shared_ptr<phi::MetaTensor>> outputs_;
paddle::SmallVector<paddle::any> attrs_;
paddle::SmallVector<paddle::any, kAttrSmallVectorSize> attrs_;
paddle::SmallVector<std::pair<int, int>> input_range_;
paddle::SmallVector<std::pair<int, int>> output_range_;
paddle::SmallVector<std::pair<int, int>, phi::kInputSmallVectorSize>
input_range_;
paddle::SmallVector<std::pair<int, int>, phi::kOutputSmallVectorSize>
output_range_;
private:
paddle::SmallVector<MetaTensor, phi::kInputSmallVectorSize> inputs_;
paddle::SmallVector<MetaTensor, phi::kOutputSmallVectorSize> outputs_;
};
#define PD_INFER_META(...) \
......@@ -159,7 +164,7 @@ struct InferMetaFnImpl<Return (*)(Args...), infer_meta_fn> {
};
template <typename... Tail>
struct InferMetaFnCallHelper<const std::vector<MetaTensor*>&, Tail...> {
struct InferMetaFnCallHelper<const std::vector<const MetaTensor*>&, Tail...> {
template <int in_idx, int attr_idx, int out_idx, typename... PreviousArgs>
static void Call(InferMetaContext* ctx, PreviousArgs&... pargs) {
static_assert(attr_idx == 0,
......@@ -167,7 +172,7 @@ struct InferMetaFnImpl<Return (*)(Args...), infer_meta_fn> {
static_assert(out_idx == 0,
"InferMeta's Input should appear before Outputs.");
const std::pair<int, int> range = ctx->InputRangeAt(in_idx);
std::vector<MetaTensor*> arg =
std::vector<const MetaTensor*> arg =
ctx->InputsBetween(range.first, range.second);
InferMetaFnCallHelper<
Tail...>::template Call<in_idx + 1, attr_idx, out_idx>(ctx,
......
......@@ -79,7 +79,7 @@ void KernelContext::EmplaceBackAttr(paddle::any attr) {
void KernelContext::AssignInputRange(std::pair<int, int>&& range, size_t idx) {
if (idx < input_range_.size()) {
input_range_[idx] = range;
input_range_[idx] = std::move(range);
} else if (idx == input_range_.size()) {
input_range_.emplace_back(range);
} else {
......@@ -93,7 +93,7 @@ void KernelContext::AssignInputRange(std::pair<int, int>&& range, size_t idx) {
void KernelContext::AssignOutputRange(std::pair<int, int>&& range, size_t idx) {
if (idx < output_range_.size()) {
output_range_[idx] = range;
output_range_[idx] = std::move(range);
} else if (idx == output_range_.size()) {
output_range_.emplace_back(range);
} else {
......
......@@ -19,6 +19,8 @@
namespace phi {
const static Kernel empty_kernel; // NOLINT
uint32_t KernelKey::Hash::operator()(const KernelKey& key) const {
uint32_t hash_value = 0;
// |----31-20------|---19-12---|---11-8----|---7-0---|
......@@ -37,15 +39,15 @@ KernelFactory& KernelFactory::Instance() {
return g_op_kernel_factory;
}
Kernel KernelFactory::SelectKernel(const std::string& kernel_name,
const KernelKey& kernel_key) const {
const Kernel& KernelFactory::SelectKernel(const std::string& kernel_name,
const KernelKey& kernel_key) const {
auto iter = kernels_.find(kernel_name);
if (iter == kernels_.end()) {
return Kernel();
return empty_kernel;
}
auto kernel_iter = iter->second.find(kernel_key);
if (kernel_iter == iter->second.end()) {
return Kernel();
return empty_kernel;
}
return kernel_iter->second;
}
......@@ -59,8 +61,8 @@ KernelKeyMap KernelFactory::SelectKernelMap(
return iter->second;
}
bool KernelFactory::IsSelectKernelValid(const std::string& kernel_name,
const KernelKey& kernel_key) const {
bool KernelFactory::HasKernel(const std::string& kernel_name,
const KernelKey& kernel_key) const {
auto iter = kernels_.find(kernel_name);
PADDLE_ENFORCE_NE(
iter,
......@@ -128,6 +130,16 @@ const Kernel& KernelFactory::SelectKernelOrThrowError(
KernelKey(backend, layout, dtype));
}
const KernelArgsDef& KernelFactory::GetFirstKernelArgsDef(
const std::string& kernel_name) const {
auto iter = kernels_.find(kernel_name);
PADDLE_ENFORCE_NE(
iter,
kernels_.end(),
phi::errors::NotFound("The kernel `%s` is not registered.", kernel_name));
return iter->second.cbegin()->second.args_def();
}
// print kernel info with json format:
// {
// "(CPU, Undefined(AnyLayout), complex64)": {
......
......@@ -151,30 +151,38 @@ class KernelArgsDef {
attribute_defs_.emplace_back(AttributeArgDef(type_index));
}
const paddle::SmallVector<TensorArgDef>& input_defs() const {
const paddle::SmallVector<TensorArgDef, kInputSmallVectorSize>& input_defs()
const {
return input_defs_;
}
const paddle::SmallVector<TensorArgDef>& output_defs() const {
const paddle::SmallVector<TensorArgDef, kOutputSmallVectorSize>& output_defs()
const {
return output_defs_;
}
const paddle::SmallVector<AttributeArgDef>& attribute_defs() const {
const paddle::SmallVector<AttributeArgDef, kAttrSmallVectorSize>&
attribute_defs() const {
return attribute_defs_;
}
paddle::SmallVector<TensorArgDef>& input_defs() { return input_defs_; }
paddle::SmallVector<TensorArgDef, kInputSmallVectorSize>& input_defs() {
return input_defs_;
}
paddle::SmallVector<TensorArgDef>& output_defs() { return output_defs_; }
paddle::SmallVector<TensorArgDef, kOutputSmallVectorSize>& output_defs() {
return output_defs_;
}
paddle::SmallVector<AttributeArgDef>& attribute_defs() {
paddle::SmallVector<AttributeArgDef, kAttrSmallVectorSize>& attribute_defs() {
return attribute_defs_;
}
private:
paddle::SmallVector<TensorArgDef> input_defs_{{}};
paddle::SmallVector<TensorArgDef> output_defs_{{}};
paddle::SmallVector<AttributeArgDef> attribute_defs_{{}};
paddle::SmallVector<TensorArgDef, kInputSmallVectorSize> input_defs_{{}};
paddle::SmallVector<TensorArgDef, kOutputSmallVectorSize> output_defs_{{}};
paddle::SmallVector<AttributeArgDef, kAttrSmallVectorSize> attribute_defs_{
{}};
};
class Kernel {
......@@ -209,7 +217,7 @@ class Kernel {
TensorArgDef& OutputAt(size_t idx) { return args_def_.output_defs().at(idx); }
bool IsValid() { return fn_ != nullptr; }
bool IsValid() const { return fn_ != nullptr; }
private:
KernelFn fn_{nullptr};
......@@ -246,14 +254,17 @@ class KernelFactory {
DataLayout layout,
DataType dtype) const;
bool IsSelectKernelValid(const std::string& kernel_name,
const KernelKey& kernel_key) const;
bool HasKernel(const std::string& kernel_name,
const KernelKey& kernel_key) const;
Kernel SelectKernel(const std::string& kernel_name,
const KernelKey& kernel_key) const;
const Kernel& SelectKernel(const std::string& kernel_name,
const KernelKey& kernel_key) const;
KernelKeyMap SelectKernelMap(const std::string& kernel_name) const;
const KernelArgsDef& GetFirstKernelArgsDef(
const std::string& kernel_name) const;
private:
KernelFactory() = default;
......
......@@ -148,4 +148,6 @@ void MetaTensor::share_dims(const MetaTensor& meta_tensor) {
}
}
bool MetaTensor::initialized() const { return tensor_ != nullptr; }
} // namespace phi
......@@ -45,10 +45,10 @@ class MetaTensor {
: tensor_(const_cast<TensorBase*>(&tensor)) {}
MetaTensor(TensorBase& tensor) : tensor_(&tensor) {} // NOLINT
MetaTensor(const MetaTensor&) = default;
MetaTensor(MetaTensor&&) = default;
MetaTensor& operator=(const MetaTensor&) = delete;
MetaTensor& operator=(MetaTensor&&) = delete;
MetaTensor& operator=(MetaTensor&&) = default;
MetaTensor(const MetaTensor&) = default;
MetaTensor& operator=(const MetaTensor&) = default;
virtual ~MetaTensor() = default;
......@@ -64,6 +64,8 @@ class MetaTensor {
virtual void share_meta(const MetaTensor& meta_tensor);
virtual void share_dims(const MetaTensor& meta_tensor);
virtual bool initialized() const;
private:
// Because the lod in compiletime and runtime is different,
// so `LoD` cannot in public methods
......
......@@ -22,7 +22,7 @@ class Kernel;
class KernelKey;
class KernelArgsDef;
class KernelContext;
class KernelSignature;
struct KernelSignature;
class ArgumentMappingContext;
class InferMetaContext;
......@@ -35,4 +35,9 @@ using ArgumentMappingFn =
std::function<KernelSignature(const ArgumentMappingContext&)>;
using InferMetaFn = void (*)(InferMetaContext* ctx);
// Global SmallVector size setting
constexpr size_t kInputSmallVectorSize = 10U;
constexpr size_t kAttrSmallVectorSize = 10U;
constexpr size_t kOutputSmallVectorSize = 5U;
} // namespace phi
......@@ -315,8 +315,8 @@ void MaxPoolWithIndexGradInferMeta(const MetaTensor& x,
dx->share_meta(x);
}
void MeshgridGradInferMeta(const std::vector<MetaTensor*>& inputs,
const std::vector<MetaTensor*>& outputs_grad,
void MeshgridGradInferMeta(const std::vector<const MetaTensor*>& inputs,
const std::vector<const MetaTensor*>& outputs_grad,
std::vector<MetaTensor*> inputs_grad) {
PADDLE_ENFORCE_GT(outputs_grad.size(),
1,
......@@ -329,7 +329,7 @@ void MeshgridGradInferMeta(const std::vector<MetaTensor*>& inputs,
}
}
void MultiDotGradInferMeta(const std::vector<MetaTensor*>& x,
void MultiDotGradInferMeta(const std::vector<const MetaTensor*>& x,
const MetaTensor& out_grad,
std::vector<MetaTensor*> x_grad) {
PADDLE_ENFORCE_EQ(
......
......@@ -151,11 +151,11 @@ void MaxPoolWithIndexGradInferMeta(const MetaTensor& x,
bool adaptive,
MetaTensor* dx);
void MeshgridGradInferMeta(const std::vector<MetaTensor*>& inputs,
const std::vector<MetaTensor*>& outputs_grad,
void MeshgridGradInferMeta(const std::vector<const MetaTensor*>& inputs,
const std::vector<const MetaTensor*>& outputs_grad,
std::vector<MetaTensor*> inputs_grad);
void MultiDotGradInferMeta(const std::vector<MetaTensor*>& x,
void MultiDotGradInferMeta(const std::vector<const MetaTensor*>& x,
const MetaTensor& out_grad,
std::vector<MetaTensor*> x_grad);
......
......@@ -21,7 +21,8 @@ limitations under the License. */
#include "paddle/phi/kernels/funcs/concat_funcs.h"
namespace phi {
std::vector<DDim> GetMetaTensorsDim(const std::vector<MetaTensor*>& tensors) {
std::vector<DDim> GetMetaTensorsDim(
const std::vector<const MetaTensor*>& tensors) {
std::vector<DDim> dims;
dims.reserve(tensors.size());
for (const MetaTensor* tensor : tensors) {
......@@ -148,7 +149,7 @@ void AdamaxInferMeta(const MetaTensor& param,
inf_norm_out->set_dtype(inf_norm.dtype());
}
void AddNInferMeta(const std::vector<MetaTensor*>& x,
void AddNInferMeta(const std::vector<const MetaTensor*>& x,
MetaTensor* out,
MetaConfig config) {
auto N = x.size();
......@@ -511,7 +512,7 @@ void BilinearTensorProductInferMeta(const MetaTensor& x,
out->set_dtype(x.dtype());
}
void BroadcastTensorsInferMeta(const std::vector<MetaTensor*>& x,
void BroadcastTensorsInferMeta(const std::vector<const MetaTensor*>& x,
std::vector<MetaTensor*> out) {
int target_rank = 0;
const auto& input_dims = GetMetaTensorsDim(x);
......@@ -565,7 +566,7 @@ void BroadcastTensorsInferMeta(const std::vector<MetaTensor*>& x,
}
}
void ConcatInferMeta(const std::vector<MetaTensor*>& x,
void ConcatInferMeta(const std::vector<const MetaTensor*>& x,
const Scalar& axis_scalar,
MetaTensor* out,
MetaConfig config) {
......@@ -1357,7 +1358,7 @@ void InterpolateInferMeta(
}
}
void MeshgridInferMeta(const std::vector<MetaTensor*>& inputs,
void MeshgridInferMeta(const std::vector<const MetaTensor*>& inputs,
std::vector<MetaTensor*> outputs) {
const size_t inputs_num = inputs.size();
......@@ -1420,7 +1421,8 @@ void MomentumInferMeta(const MetaTensor& param,
}
}
void MultiDotInferMeta(const std::vector<MetaTensor*>& x, MetaTensor* out) {
void MultiDotInferMeta(const std::vector<const MetaTensor*>& x,
MetaTensor* out) {
auto inputs_dims = GetMetaTensorsDim(x);
const size_t inputs_num = inputs_dims.size();
......@@ -1493,7 +1495,7 @@ void MultiDotInferMeta(const std::vector<MetaTensor*>& x, MetaTensor* out) {
out->share_lod(*x.at(0));
}
void MultiplexInferMeta(const std::vector<MetaTensor*>& ins,
void MultiplexInferMeta(const std::vector<const MetaTensor*>& ins,
const MetaTensor& ids,
MetaTensor* out) {
PADDLE_ENFORCE_NE(
......@@ -1672,8 +1674,8 @@ void RmspropInferMeta(const MetaTensor& param,
}
void RnnInferMeta(const MetaTensor& x,
const std::vector<MetaTensor*>& pre_state,
const std::vector<MetaTensor*>& weight_list,
const std::vector<const MetaTensor*>& pre_state,
const std::vector<const MetaTensor*>& weight_list,
paddle::optional<const MetaTensor&> sequence_length,
float dropout_prob,
bool is_bidirec,
......@@ -1779,7 +1781,7 @@ void SGDInferMeta(const MetaTensor& param,
param_out->set_dtype(param.dtype());
}
void StackInferMeta(const std::vector<MetaTensor*>& x,
void StackInferMeta(const std::vector<const MetaTensor*>& x,
int axis,
MetaTensor* out) {
PADDLE_ENFORCE_GT(x.size(),
......@@ -1825,7 +1827,7 @@ void StackInferMeta(const std::vector<MetaTensor*>& x,
out->share_lod(*x.at(0));
}
void UnchangedMultiInferMeta(const std::vector<MetaTensor*>& x,
void UnchangedMultiInferMeta(const std::vector<const MetaTensor*>& x,
std::vector<MetaTensor*> out) {
for (size_t i = 0; i < x.size(); ++i) {
out[i]->share_meta(*x[i]);
......
......@@ -35,7 +35,8 @@ namespace phi {
//
// NOTE: The InferMeta Functions in this file are arranged in alphabetic order.
std::vector<DDim> GetMetaTensorsDim(const std::vector<MetaTensor*>& tensors);
std::vector<DDim> GetMetaTensorsDim(
const std::vector<const MetaTensor*>& tensors);
void AdadeltaInferMeta(const MetaTensor& param,
const MetaTensor& grad,
......@@ -68,7 +69,7 @@ void AdamaxInferMeta(const MetaTensor& param,
MetaTensor* moment_out,
MetaTensor* inf_norm_out);
void AddNInferMeta(const std::vector<MetaTensor*>& x,
void AddNInferMeta(const std::vector<const MetaTensor*>& x,
MetaTensor* out,
MetaConfig config = MetaConfig());
......@@ -124,10 +125,10 @@ void BilinearTensorProductInferMeta(const MetaTensor& x,
MetaTensor* out,
MetaConfig config = MetaConfig());
void BroadcastTensorsInferMeta(const std::vector<MetaTensor*>& x,
void BroadcastTensorsInferMeta(const std::vector<const MetaTensor*>& x,
std::vector<MetaTensor*> out);
void ConcatInferMeta(const std::vector<MetaTensor*>& x,
void ConcatInferMeta(const std::vector<const MetaTensor*>& x,
const Scalar& axis_scalar,
MetaTensor* out,
MetaConfig config = MetaConfig());
......@@ -178,7 +179,7 @@ void InterpolateInferMeta(
MetaTensor* output,
MetaConfig config = MetaConfig());
void MeshgridInferMeta(const std::vector<MetaTensor*>& inputs,
void MeshgridInferMeta(const std::vector<const MetaTensor*>& inputs,
std::vector<MetaTensor*> outputs);
void MomentumInferMeta(const MetaTensor& param,
......@@ -196,9 +197,10 @@ void MomentumInferMeta(const MetaTensor& param,
MetaTensor* velocity_out,
MetaTensor* master_param_out);
void MultiDotInferMeta(const std::vector<MetaTensor*>& x, MetaTensor* out);
void MultiDotInferMeta(const std::vector<const MetaTensor*>& x,
MetaTensor* out);
void MultiplexInferMeta(const std::vector<MetaTensor*>& ins,
void MultiplexInferMeta(const std::vector<const MetaTensor*>& ins,
const MetaTensor& ids,
MetaTensor* out);
......@@ -227,8 +229,8 @@ void RmspropInferMeta(const MetaTensor& param,
MetaTensor* mean_grad_out);
void RnnInferMeta(const MetaTensor& x,
const std::vector<MetaTensor*>& pre_state,
const std::vector<MetaTensor*>& weight_list,
const std::vector<const MetaTensor*>& pre_state,
const std::vector<const MetaTensor*>& weight_list,
paddle::optional<const MetaTensor&> sequence_length,
float dropout_prob,
bool is_bidirec,
......@@ -251,11 +253,11 @@ void SGDInferMeta(const MetaTensor& param,
MetaTensor* param_out,
MetaTensor* master_param_out);
void StackInferMeta(const std::vector<MetaTensor*>& x,
void StackInferMeta(const std::vector<const MetaTensor*>& x,
int axis,
MetaTensor* out);
void UnchangedMultiInferMeta(const std::vector<MetaTensor*>& x,
void UnchangedMultiInferMeta(const std::vector<const MetaTensor*>& x,
std::vector<MetaTensor*> out);
void WarpctcInferMeta(const MetaTensor& logits,
......
......@@ -32,7 +32,7 @@ DenseTensor Concat(const Context& dev_ctx,
const Scalar& axis) {
std::vector<MetaTensor> meta_x;
meta_x.reserve(x.size());
std::vector<MetaTensor*> meta_x_ptr;
std::vector<const MetaTensor*> meta_x_ptr;
for (const auto* t : x) {
meta_x.emplace_back(*t);
meta_x_ptr.push_back(&meta_x.back());
......
......@@ -21,8 +21,7 @@ KernelSignature AbsOpArgumentMapping(const ArgumentMappingContext& ctx) {
}
KernelSignature AbsGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature(
"abs_grad", {"X", GradVarName("Out")}, {}, {GradVarName("X")});
return KernelSignature("abs_grad", {"X", "Out@GRAD"}, {}, {"X@GRAD"});
}
KernelSignature AbsDoubleGradOpArgumentMapping(
......
......@@ -19,26 +19,22 @@ namespace phi {
#define DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(func_name, op_name, attrs) \
KernelSignature func_name##GradOpArgumentMapping( \
const ArgumentMappingContext& ctx) { \
return KernelSignature(op_name "_grad", \
{"X", GradVarName("Out")}, \
{attrs}, \
{GradVarName("X")}); \
return KernelSignature( \
op_name "_grad", {"X", "Out@GRAD"}, {attrs}, {"X@GRAD"}); \
}
#define DEFINE_ACT_GRAD_DEPOUT_OP_ARGMAP(func_name, op_name, attrs) \
KernelSignature func_name##GradOpArgumentMapping( \
const ArgumentMappingContext& ctx) { \
return KernelSignature(op_name "_grad", \
{"Out", GradVarName("Out")}, \
{attrs}, \
{GradVarName("X")}); \
return KernelSignature( \
op_name "_grad", {"Out", "Out@GRAD"}, {attrs}, {"X@GRAD"}); \
}
#define DEFINE_ACT_GRAD_NODEP_OP_ARGMAP(func_name, op_name, attrs) \
KernelSignature func_name##GradOpArgumentMapping( \
const ArgumentMappingContext& ctx) { \
return KernelSignature( \
op_name "_grad", {GradVarName("Out")}, {attrs}, {GradVarName("X")}); \
#define DEFINE_ACT_GRAD_NODEP_OP_ARGMAP(func_name, op_name, attrs) \
KernelSignature func_name##GradOpArgumentMapping( \
const ArgumentMappingContext& ctx) { \
return KernelSignature( \
op_name "_grad", {"Out@GRAD"}, {attrs}, {"X@GRAD"}); \
}
#define comma ,
......@@ -165,15 +161,12 @@ KernelSignature EluOpArgumentMapping(const ArgumentMappingContext& ctx) {
}
KernelSignature LogitGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature(
"logit_grad", {"X", GradVarName("Out")}, {"eps"}, {GradVarName("X")});
return KernelSignature("logit_grad", {"X", "Out@GRAD"}, {"eps"}, {"X@GRAD"});
}
KernelSignature EluGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("elu_grad",
{"X", "Out", GradVarName("Out")},
{"alpha"},
{GradVarName("X")});
return KernelSignature(
"elu_grad", {"X", "Out", "Out@GRAD"}, {"alpha"}, {"X@GRAD"});
}
KernelSignature EluDoubleGradOpArgumentMapping(
......@@ -198,13 +191,11 @@ KernelSignature PowOpArgumentMapping(const ArgumentMappingContext& ctx) {
KernelSignature PowGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
if (ctx.HasInput("FactorTensor")) {
return KernelSignature("pow_grad",
{"X", GradVarName("Out")},
{"FactorTensor"},
{GradVarName("X")});
return KernelSignature(
"pow_grad", {"X", "Out@GRAD"}, {"FactorTensor"}, {"X@GRAD"});
} else {
return KernelSignature(
"pow_grad", {"X", GradVarName("Out")}, {"factor"}, {GradVarName("X")});
"pow_grad", {"X", "Out@GRAD"}, {"factor"}, {"X@GRAD"});
}
}
......
......@@ -17,11 +17,10 @@
namespace phi {
KernelSignature AddmmGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature(
"addmm_grad",
{"Input", "X", "Y", GradVarName("Out")},
{"Alpha", "Beta"},
{GradVarName("Input"), GradVarName("X"), GradVarName("Y")});
return KernelSignature("addmm_grad",
{"Input", "X", "Y", "Out@GRAD"},
{"Alpha", "Beta"},
{"Input@GRAD", "X@GRAD", "Y@GRAD"});
}
} // namespace phi
......
......@@ -19,9 +19,9 @@ namespace phi {
KernelSignature ArgsortGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("argsort_grad",
{"Indices", "X", GradVarName("Out")},
{"Indices", "X", "Out@GRAD"},
{"axis", "descending"},
{GradVarName("X")});
{"X@GRAD"});
}
} // namespace phi
......
......@@ -17,10 +17,8 @@
namespace phi {
KernelSignature Atan2GradOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("atan2_grad",
{"X1", "X2", GradVarName("Out")},
{},
{GradVarName("X1"), GradVarName("X2")});
return KernelSignature(
"atan2_grad", {"X1", "X2", "Out@GRAD"}, {}, {"X1@GRAD", "X2@GRAD"});
}
} // namespace phi
......
......@@ -57,27 +57,26 @@ KernelSignature BatchNormOpArgumentMapping(const ArgumentMappingContext& ctx) {
KernelSignature BatchNormGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature(
"batch_norm_grad",
{
"X",
"Scale",
"Bias",
"Mean",
"Variance",
"SavedMean",
"SavedVariance",
"ReserveSpace",
GradVarName("Y"),
},
{"momentum",
"epsilon",
"data_layout",
"is_test",
"use_global_stats",
"trainable_statistics",
"fuse_with_relu"},
{GradVarName("X"), GradVarName("Scale"), GradVarName("Bias")});
return KernelSignature("batch_norm_grad",
{
"X",
"Scale",
"Bias",
"Mean",
"Variance",
"SavedMean",
"SavedVariance",
"ReserveSpace",
"Y@GRAD",
},
{"momentum",
"epsilon",
"data_layout",
"is_test",
"use_global_stats",
"trainable_statistics",
"fuse_with_relu"},
{"X@GRAD", "Scale@GRAD", "Bias@GRAD"});
}
KernelSignature BatchNormGradGradOpArgumentMapping(
......
......@@ -18,10 +18,8 @@ namespace phi {
KernelSignature BCELossGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("bce_loss_grad",
{"X", "Label", GradVarName("Out")},
{},
{GradVarName("X")});
return KernelSignature(
"bce_loss_grad", {"X", "Label", "Out@GRAD"}, {}, {"X@GRAD"});
}
} // namespace phi
......
......@@ -25,12 +25,9 @@ KernelSignature BilinearTensorProductOpArgumentMapping(
KernelSignature BilinearTensorProductGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("bilinear_tensor_product_grad",
{"X", "Y", "Weight", GradVarName("Out")},
{"X", "Y", "Weight", "Out@GRAD"},
{},
{GradVarName("X"),
GradVarName("Y"),
GradVarName("Weight"),
GradVarName("Bias")});
{"X@GRAD", "Y@GRAD", "Weight@GRAD", "Bias@GRAD"});
}
} // namespace phi
......
......@@ -19,7 +19,7 @@ namespace phi {
KernelSignature BroadcastTensorsGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature(
"broadcast_tensors_grad", {GradVarName("Out")}, {}, {GradVarName("X")});
"broadcast_tensors_grad", {"Out@GRAD"}, {}, {"X@GRAD"});
}
} // namespace phi
......
......@@ -18,10 +18,8 @@ namespace phi {
KernelSignature CholeskyGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("cholesky_grad",
{"Out", GradVarName("Out")},
{"upper"},
{GradVarName("X")});
return KernelSignature(
"cholesky_grad", {"Out", "Out@GRAD"}, {"upper"}, {"X@GRAD"});
}
} // namespace phi
......
......@@ -19,9 +19,9 @@ namespace phi {
KernelSignature CholeskySolveGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("cholesky_solve_grad",
{"X", "Y", "Out", GradVarName("Out")},
{"X", "Y", "Out", "Out@GRAD"},
{"upper"},
{GradVarName("X"), GradVarName("Y")});
{"X@GRAD", "Y@GRAD"});
}
} // namespace phi
......
......@@ -18,7 +18,7 @@
namespace phi {
KernelSignature ClipOpArgumentMapping(const ArgumentMappingContext& ctx) {
paddle::SmallVector<std::string> attr_names;
paddle::SmallVector<std::string, kAttrSmallVectorSize> attr_names;
attr_names.emplace_back(ctx.HasInput("Min") ? "Min" : "min");
attr_names.emplace_back(ctx.HasInput("Max") ? "Max" : "max");
if (ctx.IsDenseTensorInput("X")) {
......@@ -57,27 +57,19 @@ KernelSignature ClipOpArgumentMapping(const ArgumentMappingContext& ctx) {
KernelSignature ClipGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
if (ctx.HasInput("Min")) {
if (ctx.HasInput("Max")) {
return KernelSignature("clip_grad",
{"X", GradVarName("Out")},
{"Min", "Max"},
{GradVarName("X")});
return KernelSignature(
"clip_grad", {"X", "Out@GRAD"}, {"Min", "Max"}, {"X@GRAD"});
} else {
return KernelSignature("clip_grad",
{"X", GradVarName("Out")},
{"Min", "max"},
{GradVarName("X")});
return KernelSignature(
"clip_grad", {"X", "Out@GRAD"}, {"Min", "max"}, {"X@GRAD"});
}
} else {
if (ctx.HasInput("Max")) {
return KernelSignature("clip_grad",
{"X", GradVarName("Out")},
{"min", "Max"},
{GradVarName("X")});
return KernelSignature(
"clip_grad", {"X", "Out@GRAD"}, {"min", "Max"}, {"X@GRAD"});
} else {
return KernelSignature("clip_grad",
{"X", GradVarName("Out")},
{"min", "max"},
{GradVarName("X")});
return KernelSignature(
"clip_grad", {"X", "Out@GRAD"}, {"min", "max"}, {"X@GRAD"});
}
}
}
......
......@@ -17,13 +17,11 @@
namespace phi {
KernelSignature RealGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature(
"real_grad", {GradVarName("Out")}, {}, {GradVarName("X")});
return KernelSignature("real_grad", {"Out@GRAD"}, {}, {"X@GRAD"});
}
KernelSignature ImagGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature(
"imag_grad", {GradVarName("Out")}, {}, {GradVarName("X")});
return KernelSignature("imag_grad", {"Out@GRAD"}, {}, {"X@GRAD"});
}
} // namespace phi
......
......@@ -25,15 +25,11 @@ KernelSignature ConcatOpArgumentMapping(const ArgumentMappingContext& ctx) {
KernelSignature ConcatGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
if (ctx.HasInput("AxisTensor")) {
return KernelSignature("concat_grad",
{"X", {GradVarName("Out")}},
{"AxisTensor"},
{{GradVarName("X")}});
return KernelSignature(
"concat_grad", {"X", {"Out@GRAD"}}, {"AxisTensor"}, {{"X@GRAD"}});
}
return KernelSignature("concat_grad",
{"X", {GradVarName("Out")}},
{"axis"},
{{GradVarName("X")}});
return KernelSignature(
"concat_grad", {"X", {"Out@GRAD"}}, {"axis"}, {{"X@GRAD"}});
}
} // namespace phi
......
......@@ -46,7 +46,7 @@ KernelSignature Conv2dOpArgumentMapping(const ArgumentMappingContext& ctx) {
KernelSignature Conv2dGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("conv2d_grad",
{"Input", "Filter", GradVarName("Output")},
{"Input", "Filter", "Output@GRAD"},
{"strides",
"paddings",
"padding_algorithm",
......@@ -56,7 +56,7 @@ KernelSignature Conv2dGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
"use_addto",
"workspace_size_MB",
"exhaustive_search"},
{GradVarName("Input"), GradVarName("Filter")});
{"Input@GRAD", "Filter@GRAD"});
}
KernelSignature Conv2dDoubleGradOpArgumentMapping(
......
......@@ -33,7 +33,7 @@ KernelSignature Conv3dOpArgumentMapping(const ArgumentMappingContext& ctx) {
KernelSignature Conv3dGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("conv2d_grad",
{"Input", "Filter", GradVarName("Output")},
{"Input", "Filter", "Output@GRAD"},
{"strides",
"paddings",
"padding_algorithm",
......@@ -43,7 +43,7 @@ KernelSignature Conv3dGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
"use_addto",
"workspace_size_MB",
"exhaustive_search"},
{GradVarName("Input"), GradVarName("Filter")});
{"Input@GRAD", "Filter@GRAD"});
}
KernelSignature Conv3dDoubleGradOpArgumentMapping(
......
......@@ -34,7 +34,7 @@ KernelSignature Conv2dTransposeOpArgumentMapping(
KernelSignature Conv2dTransposeGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("conv2d_transpose_grad",
{"Input", "Filter", GradVarName("Output")},
{"Input", "Filter", "Output@GRAD"},
{"strides",
"paddings",
"output_padding",
......@@ -43,7 +43,7 @@ KernelSignature Conv2dTransposeGradOpArgumentMapping(
"groups",
"dilations",
"data_format"},
{GradVarName("Input"), GradVarName("Filter")});
{"Input@GRAD", "Filter@GRAD"});
}
KernelSignature Conv2dTransposeDoubleGradOpArgumentMapping(
......@@ -79,7 +79,7 @@ KernelSignature Conv3dTransposeOpArgumentMapping(
KernelSignature Conv3dTransposeGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("conv3d_transpose_grad",
{"Input", "Filter", GradVarName("Output")},
{"Input", "Filter", "Output@GRAD"},
{"strides",
"paddings",
"output_padding",
......@@ -88,7 +88,7 @@ KernelSignature Conv3dTransposeGradOpArgumentMapping(
"groups",
"dilations",
"data_format"},
{GradVarName("Input"), GradVarName("Filter")});
{"Input@GRAD", "Filter@GRAD"});
}
KernelSignature DepthwiseConv2dTransposeOpArgumentMapping(
......@@ -109,7 +109,7 @@ KernelSignature DepthwiseConv2dTransposeOpArgumentMapping(
KernelSignature DepthwiseConv2dTransposeGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("depthwise_conv2d_transpose_grad",
{"Input", "Filter", GradVarName("Output")},
{"Input", "Filter", "Output@GRAD"},
{"strides",
"paddings",
"output_padding",
......@@ -118,7 +118,7 @@ KernelSignature DepthwiseConv2dTransposeGradOpArgumentMapping(
"groups",
"dilations",
"data_format"},
{GradVarName("Input"), GradVarName("Filter")});
{"Input@GRAD", "Filter@GRAD"});
}
} // namespace phi
......
......@@ -21,10 +21,8 @@ KernelSignature CrossOpArgumentMapping(const ArgumentMappingContext& ctx) {
}
KernelSignature CrossGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("cross_grad",
{"X", "Y", GradVarName("Out")},
{"dim"},
{GradVarName("X"), GradVarName("Y")});
return KernelSignature(
"cross_grad", {"X", "Y", "Out@GRAD"}, {"dim"}, {"X@GRAD", "Y@GRAD"});
}
} // namespace phi
......
......@@ -18,10 +18,8 @@ namespace phi {
KernelSignature CumprodGradGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("cumprod_grad",
{"X", "Out", GradVarName("Out")},
{"dim"},
{GradVarName("X")});
return KernelSignature(
"cumprod_grad", {"X", "Out", "Out@GRAD"}, {"dim"}, {"X@GRAD"});
}
} // namespace phi
......
......@@ -33,17 +33,14 @@ KernelSignature DeformableConvGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature(
"deformable_conv_grad",
{"Input", "Offset", "Filter", "Mask", GradVarName("Output")},
{"Input", "Offset", "Filter", "Mask", "Output@GRAD"},
{"strides",
"paddings",
"dilations",
"deformable_groups",
"groups",
"im2col_step"},
{GradVarName("Input"),
GradVarName("Offset"),
GradVarName("Filter"),
GradVarName("Mask")});
{"Input@GRAD", "Offset@GRAD", "Filter@GRAD", "Mask@GRAD"});
}
} // namespace phi
......
......@@ -36,7 +36,7 @@ KernelSignature DepthwiseConv2dOpArgumentMapping(
KernelSignature DepthwiseConv2dGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("depthwise_conv2d_grad",
{"Input", "Filter", GradVarName("Output")},
{"Input", "Filter", "Output@GRAD"},
{"strides",
"paddings",
"padding_algorithm",
......@@ -47,7 +47,7 @@ KernelSignature DepthwiseConv2dGradOpArgumentMapping(
"workspace_size_MB",
"exhaustive_search",
"fuse_relu_before_depthwise_conv"},
{GradVarName("Input"), GradVarName("Filter")});
{"Input@GRAD", "Filter@GRAD"});
}
KernelSignature DepthwiseConv2dDoubleGradOpArgumentMapping(
......
......@@ -18,10 +18,8 @@ namespace phi {
KernelSignature DeterminantGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("determinant_grad",
{"Input", "Out", GradVarName("Out")},
{},
{GradVarName("Input")});
return KernelSignature(
"determinant_grad", {"Input", "Out", "Out@GRAD"}, {}, {"Input@GRAD"});
}
} // namespace phi
......
......@@ -22,7 +22,7 @@ KernelSignature DiagOpArgumentMapping(const ArgumentMappingContext& ctx) {
KernelSignature DiagGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature(
"diag_grad", {"X", GradVarName("Out")}, {"offset"}, {GradVarName("X")});
"diag_grad", {"X", "Out@GRAD"}, {"offset"}, {"X@GRAD"});
}
} // namespace phi
......
......@@ -19,9 +19,9 @@ namespace phi {
KernelSignature DiagonalGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("diagonal_grad",
{"Input", GradVarName("Out")},
{"Input", "Out@GRAD"},
{"offset", "axis1", "axis2"},
{GradVarName("Input")});
{"Input@GRAD"});
}
} // namespace phi
......
......@@ -18,8 +18,7 @@ namespace phi {
KernelSignature DigammaGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature(
"digamma_grad", {"X", GradVarName("Out")}, {}, {GradVarName("X")});
return KernelSignature("digamma_grad", {"X", "Out@GRAD"}, {}, {"X@GRAD"});
}
} // namespace phi
......
......@@ -17,10 +17,8 @@ limitations under the License. */
namespace phi {
KernelSignature DistGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("dist_grad",
{"X", "Y", "Out", GradVarName("Out")},
{"p"},
{GradVarName("X"), GradVarName("Y")});
return KernelSignature(
"dist_grad", {"X", "Y", "Out", "Out@GRAD"}, {"p"}, {"X@GRAD", "Y@GRAD"});
}
} // namespace phi
......
......@@ -17,10 +17,8 @@ limitations under the License. */
namespace phi {
KernelSignature DotGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("dot_grad",
{"X", "Y", GradVarName("Out")},
{},
{GradVarName("X"), GradVarName("Y")});
return KernelSignature(
"dot_grad", {"X", "Y", "Out@GRAD"}, {}, {"X@GRAD", "Y@GRAD"});
}
} // namespace phi
......
......@@ -27,9 +27,9 @@ KernelSignature DropoutOpArgumentMapping(const ArgumentMappingContext& ctx) {
KernelSignature DropoutGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("dropout_grad",
{"Mask", GradVarName("Out")},
{"Mask", "Out@GRAD"},
{"dropout_prob", "is_test", "dropout_implementation"},
{GradVarName("X")});
{"X@GRAD"});
}
} // namespace phi
......
......@@ -17,13 +17,11 @@
namespace phi {
KernelSignature EighGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("eigh_grad",
{"Eigenvalues",
"Eigenvectors",
GradVarName("Eigenvalues"),
GradVarName("Eigenvectors")},
{},
{GradVarName("X")});
return KernelSignature(
"eigh_grad",
{"Eigenvalues", "Eigenvectors", "Eigenvalues@GRAD", "Eigenvectors@GRAD"},
{},
{"X@GRAD"});
}
} // namespace phi
......
......@@ -106,10 +106,8 @@ KernelSignature ElementwisePowOpArgumentMapping(
KernelSignature ElementwiseAddGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("add_grad",
{"X", "Y", GradVarName("Out")},
{"axis"},
{GradVarName("X"), GradVarName("Y")});
return KernelSignature(
"add_grad", {"X", "Y", "Out@GRAD"}, {"axis"}, {"X@GRAD", "Y@GRAD"});
}
KernelSignature ElementwiseAddDoubleGradOpArgumentMapping(
......@@ -128,10 +126,8 @@ KernelSignature ElementwiseAddTripleGradOpArgumentMapping(
KernelSignature ElementwiseSubGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("subtract_grad",
{"X", "Y", GradVarName("Out")},
{"axis"},
{GradVarName("X"), GradVarName("Y")});
return KernelSignature(
"subtract_grad", {"X", "Y", "Out@GRAD"}, {"axis"}, {"X@GRAD", "Y@GRAD"});
}
KernelSignature ElementwiseSubDoubleGradOpArgumentMapping(
......@@ -143,17 +139,15 @@ KernelSignature ElementwiseSubDoubleGradOpArgumentMapping(
KernelSignature ElementwiseDivGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("divide_grad",
{"X", "Y", "Out", GradVarName("Out")},
{"X", "Y", "Out", "Out@GRAD"},
{"axis"},
{GradVarName("X"), GradVarName("Y")});
{"X@GRAD", "Y@GRAD"});
}
KernelSignature ElementwiseFMinGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("fmin_grad",
{"X", "Y", GradVarName("Out")},
{"axis"},
{GradVarName("X"), GradVarName("Y")});
return KernelSignature(
"fmin_grad", {"X", "Y", "Out@GRAD"}, {"axis"}, {"X@GRAD", "Y@GRAD"});
}
KernelSignature ElementwiseDivDoubleGradOpArgumentMapping(
......@@ -161,15 +155,13 @@ KernelSignature ElementwiseDivDoubleGradOpArgumentMapping(
return KernelSignature("divide_double_grad",
{"Y", "Out", "DX", "DDX", "DDY"},
{"axis"},
{GradVarName("Y"), "DOut", "DDOut"});
{"Y@GRAD", "DOut", "DDOut"});
}
KernelSignature ElementwiseMulGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("multiply_grad",
{"X", "Y", GradVarName("Out")},
{"axis"},
{GradVarName("X"), GradVarName("Y")});
return KernelSignature(
"multiply_grad", {"X", "Y", "Out@GRAD"}, {"axis"}, {"X@GRAD", "Y@GRAD"});
}
KernelSignature ElementwiseFMaxOpArgumentMapping(
......@@ -184,10 +176,8 @@ KernelSignature ElementwiseFMinOpArgumentMapping(
KernelSignature ElementwiseFMaxGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("fmax_grad",
{"X", "Y", GradVarName("Out")},
{"axis"},
{GradVarName("X"), GradVarName("Y")});
return KernelSignature(
"fmax_grad", {"X", "Y", "Out@GRAD"}, {"axis"}, {"X@GRAD", "Y@GRAD"});
}
KernelSignature ElementwiseMulDoubleGradOpArgumentMapping(
......@@ -195,7 +185,7 @@ KernelSignature ElementwiseMulDoubleGradOpArgumentMapping(
return KernelSignature("multiply_double_grad",
{"X", "Y", "DOut", "DDX", "DDY"},
{"axis"},
{GradVarName("X"), GradVarName("Y"), "DDOut"});
{"X@GRAD", "Y@GRAD", "DDOut"});
}
KernelSignature ElementwiseMulTripleGradOpArgumentMapping(
......@@ -209,25 +199,21 @@ KernelSignature ElementwiseMulTripleGradOpArgumentMapping(
KernelSignature ElementwiseMaxGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("maximum_grad",
{"X", "Y", GradVarName("Out")},
{"axis"},
{GradVarName("X"), GradVarName("Y")});
return KernelSignature(
"maximum_grad", {"X", "Y", "Out@GRAD"}, {"axis"}, {"X@GRAD", "Y@GRAD"});
}
KernelSignature ElementwiseMinGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("minimum_grad",
{"X", "Y", GradVarName("Out")},
{"axis"},
{GradVarName("X"), GradVarName("Y")});
return KernelSignature(
"minimum_grad", {"X", "Y", "Out@GRAD"}, {"axis"}, {"X@GRAD", "Y@GRAD"});
}
KernelSignature ElementwisePowGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("elementwise_pow_grad",
{"X", "Y", GradVarName("Out")},
{"X", "Y", "Out@GRAD"},
{"axis"},
{GradVarName("X"), GradVarName("Y")});
{"X@GRAD", "Y@GRAD"});
}
} // namespace phi
......
......@@ -30,26 +30,26 @@ KernelSignature EmbeddingGradOpArgumentMapping(
if (ctx.IsDenseTensorInput("W")) {
if ((paddle::any_cast<bool>(ctx.Attr("is_sparse"))) == true) {
return KernelSignature("embedding_sparse_grad",
{"Ids", "W", GradVarName("Out")},
{"Ids", "W", "Out@GRAD"},
{"padding_idx"},
{GradVarName("W")});
{"W@GRAD"});
} else {
return KernelSignature("embedding_grad",
{"Ids", "W", GradVarName("Out")},
{"Ids", "W", "Out@GRAD"},
{"padding_idx"},
{GradVarName("W")});
{"W@GRAD"});
}
} else {
if ((paddle::any_cast<bool>(ctx.Attr("is_sparse"))) == true) {
return KernelSignature("sparse_weight_embedding_sparse_grad",
{"Ids", "W", GradVarName("Out")},
{"Ids", "W", "Out@GRAD"},
{"padding_idx"},
{GradVarName("W")});
{"W@GRAD"});
} else {
return KernelSignature("sparse_weight_embedding_grad",
{"Ids", "W", GradVarName("Out")},
{"Ids", "W", "Out@GRAD"},
{"padding_idx"},
{GradVarName("W")});
{"W@GRAD"});
}
}
}
......
......@@ -17,8 +17,7 @@
namespace phi {
KernelSignature ErfGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature(
"erf_grad", {"X", GradVarName("Out")}, {}, {GradVarName("X")});
return KernelSignature("erf_grad", {"X", "Out@GRAD"}, {}, {"X@GRAD"});
}
} // namespace phi
......
......@@ -17,8 +17,7 @@
namespace phi {
KernelSignature ErfinvGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature(
"erfinv_grad", {"Out", GradVarName("Out")}, {}, {GradVarName("X")});
return KernelSignature("erfinv_grad", {"Out", "Out@GRAD"}, {}, {"X@GRAD"});
}
} // namespace phi
......
......@@ -22,10 +22,8 @@ KernelSignature ExpandAsOpArgumentMapping(const ArgumentMappingContext& ctx) {
KernelSignature ExpandAsGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("expand_as_grad",
{"X", GradVarName("Out")},
{"target_shape"},
{GradVarName("X")});
return KernelSignature(
"expand_as_grad", {"X", "Out@GRAD"}, {"target_shape"}, {"X@GRAD"});
}
} // namespace phi
......
......@@ -39,20 +39,14 @@ KernelSignature ExpandGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
"expand_grad", {"X", "Out@GRAD"}, {"shape"}, {"X@GRAD"});
}
if (ctx.HasInput("Shape")) {
return KernelSignature("expand_grad",
{"X", GradVarName("Out")},
{"Shape"},
{GradVarName("X")});
return KernelSignature(
"expand_grad", {"X", "Out@GRAD"}, {"Shape"}, {"X@GRAD"});
} else if (ctx.InputSize("expand_shapes_tensor") > 0) {
return KernelSignature("expand_grad",
{"X", GradVarName("Out")},
{"expand_shapes_tensor"},
{GradVarName("X")});
return KernelSignature(
"expand_grad", {"X", "Out@GRAD"}, {"expand_shapes_tensor"}, {"X@GRAD"});
} else {
return KernelSignature("expand_grad",
{"X", GradVarName("Out")},
{"shape"},
{GradVarName("X")});
return KernelSignature(
"expand_grad", {"X", "Out@GRAD"}, {"shape"}, {"X@GRAD"});
}
}
......
......@@ -31,7 +31,7 @@ KernelSignature FlattenOpArgumentMapping(const ArgumentMappingContext& ctx) {
KernelSignature FlattenGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature(
"flatten_grad", {"XShape", GradVarName("Out")}, {}, {GradVarName("X")});
"flatten_grad", {"XShape", "Out@GRAD"}, {}, {"X@GRAD"});
}
} // namespace phi
......
......@@ -25,9 +25,9 @@ KernelSignature FrobeniusNormOpArgumentMapping(
KernelSignature FrobeniusNormGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("frobenius_norm_grad",
{"X", "Out", GradVarName("Out")},
{"X", "Out", "Out@GRAD"},
{"dim", "keep_dim", "reduce_all"},
{GradVarName("X")});
{"X@GRAD"});
}
} // namespace phi
......
......@@ -17,25 +17,23 @@
namespace phi {
KernelSignature GatherNdGradArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("gather_nd_grad",
{"X", "Index", GradVarName("Out")},
{},
{GradVarName("X")});
return KernelSignature(
"gather_nd_grad", {"X", "Index", "Out@GRAD"}, {}, {"X@GRAD"});
}
KernelSignature ScatterGradArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("scatter_grad",
{"Ids", "Updates", GradVarName("Out")},
{"Ids", "Updates", "Out@GRAD"},
{"overwrite"},
{GradVarName("X"), GradVarName("Updates")});
{"X@GRAD", "Updates@GRAD"});
}
KernelSignature ScatterNdAddGradArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("scatter_nd_add_grad",
{"Index", "Updates", GradVarName("Out")},
{"Index", "Updates", "Out@GRAD"},
{},
{GradVarName("X"), GradVarName("Updates")});
{"X@GRAD", "Updates@GRAD"});
}
} // namespace phi
......
......@@ -27,14 +27,14 @@ KernelSignature GatherOpArgumentMapping(const ArgumentMappingContext& ctx) {
KernelSignature GatherGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
if (ctx.HasInput("Axis")) {
return KernelSignature("gather_grad",
{"X", "Index", GradVarName("Out")},
{"X", "Index", "Out@GRAD"},
{"Axis", "overwrite"},
{GradVarName("X")});
{"X@GRAD"});
} else {
return KernelSignature("gather_grad",
{"X", "Index", GradVarName("Out")},
{"X", "Index", "Out@GRAD"},
{"axis", "overwrite"},
{GradVarName("X")});
{"X@GRAD"});
}
}
......
......@@ -21,10 +21,8 @@ KernelSignature GeluOpArgumentMapping(const ArgumentMappingContext& ctx) {
}
KernelSignature GeluGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("gelu_grad",
{"X", GradVarName("Out")},
{"approximate"},
{GradVarName("X")});
return KernelSignature(
"gelu_grad", {"X", "Out@GRAD"}, {"approximate"}, {"X@GRAD"});
}
} // namespace phi
......
......@@ -28,9 +28,9 @@ KernelSignature GraphSendRecvGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature(
"graph_send_recv_grad",
{"X", "Src_index", "Dst_index", "Out", "Dst_count", GradVarName("Out")},
{"X", "Src_index", "Dst_index", "Out", "Dst_count", "Out@GRAD"},
{"pool_type"},
{GradVarName("X")});
{"X@GRAD"});
}
} // namespace phi
......
......@@ -27,9 +27,9 @@ KernelSignature GridSamplerOpArgumentMapping(
KernelSignature GridSamplerGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("grid_sample_grad",
{"X", "Grid", GradVarName("Output")},
{"X", "Grid", "Output@GRAD"},
{"mode", "padding_mode", "align_corners"},
{GradVarName("X"), GradVarName("Grid")});
{"X@GRAD", "Grid@GRAD"});
}
} // namespace phi
......
......@@ -18,10 +18,8 @@ namespace phi {
KernelSignature GumbelSoftmaxGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("gumbel_softmax_grad",
{"Out", GradVarName("Out")},
{"axis"},
{GradVarName("X")});
return KernelSignature(
"gumbel_softmax_grad", {"Out", "Out@GRAD"}, {"axis"}, {"X@GRAD"});
}
} // namespace phi
......
......@@ -32,44 +32,42 @@ KernelSignature HierarchicalSigmoidOpArgumentMapping(
KernelSignature HierarchicalSigmoidGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
if (ctx.IsDenseTensorOutput(GradVarName("W"))) {
return KernelSignature(
"hierarchical_sigmoid_grad",
{"X",
"W",
"Label",
"PathTable",
"PathCode",
"Bias",
"PreOut",
GradVarName("Out")},
{"num_classes",
"remote_prefetch",
"trainer_id",
"height_sections",
"epmap",
"table_names",
"is_sparse"},
{GradVarName("X"), GradVarName("W"), GradVarName("Bias")});
} else if (ctx.IsSelectedRowsOutput(GradVarName("W"))) {
return KernelSignature(
"hierarchical_sigmoid_grad_sr",
{"X",
"W",
"Label",
"PathTable",
"PathCode",
"Bias",
"PreOut",
GradVarName("Out")},
{"num_classes",
"remote_prefetch",
"trainer_id",
"height_sections",
"epmap",
"table_names",
"is_sparse"},
{GradVarName("X"), GradVarName("W"), GradVarName("Bias")});
if (ctx.IsDenseTensorOutput("W@GRAD")) {
return KernelSignature("hierarchical_sigmoid_grad",
{"X",
"W",
"Label",
"PathTable",
"PathCode",
"Bias",
"PreOut",
"Out@GRAD"},
{"num_classes",
"remote_prefetch",
"trainer_id",
"height_sections",
"epmap",
"table_names",
"is_sparse"},
{"X@GRAD", "W@GRAD", "Bias@GRAD"});
} else if (ctx.IsSelectedRowsOutput("W@GRAD")) {
return KernelSignature("hierarchical_sigmoid_grad_sr",
{"X",
"W",
"Label",
"PathTable",
"PathCode",
"Bias",
"PreOut",
"Out@GRAD"},
{"num_classes",
"remote_prefetch",
"trainer_id",
"height_sections",
"epmap",
"table_names",
"is_sparse"},
{"X@GRAD", "W@GRAD", "Bias@GRAD"});
} else {
return KernelSignature("unregistered", {}, {}, {});
}
......
......@@ -24,9 +24,9 @@ KernelSignature HuberLossOpArgumentMapping(const ArgumentMappingContext& ctx) {
KernelSignature HuberLossGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("huber_loss_grad",
{"Residual", GradVarName("Out")},
{"Residual", "Out@GRAD"},
{"delta"},
{GradVarName("X"), GradVarName("Y")});
{"X@GRAD", "Y@GRAD"});
}
} // namespace phi
......
......@@ -18,10 +18,8 @@ namespace phi {
KernelSignature IndexSampleGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("index_sample_grad",
{"X", "Index", GradVarName("Out")},
{},
{GradVarName("X")});
return KernelSignature(
"index_sample_grad", {"X", "Index", "Out@GRAD"}, {}, {"X@GRAD"});
}
} // namespace phi
......
......@@ -18,10 +18,8 @@ namespace phi {
KernelSignature IndexSelectGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("index_select_grad",
{"X", "Index", GradVarName("Out")},
{"dim"},
{GradVarName("X")});
return KernelSignature(
"index_select_grad", {"X", "Index", "Out@GRAD"}, {"dim"}, {"X@GRAD"});
}
} // namespace phi
......
......@@ -92,81 +92,76 @@ KernelSignature BicubicInterpOpArgumentMapping(
KernelSignature BilinearInterpGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature(
"bilinear_interp_v2_grad",
{"X", "OutSize", "SizeTensor", "Scale", GradVarName("Out")},
{"data_layout",
"out_d",
"out_h",
"out_w",
"scale",
"interp_method",
"align_corners",
"align_mode"},
{GradVarName("X")});
return KernelSignature("bilinear_interp_v2_grad",
{"X", "OutSize", "SizeTensor", "Scale", "Out@GRAD"},
{"data_layout",
"out_d",
"out_h",
"out_w",
"scale",
"interp_method",
"align_corners",
"align_mode"},
{"X@GRAD"});
}
KernelSignature NearestInterpGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature(
"nearest_interp_v2_grad",
{"X", "OutSize", "SizeTensor", "Scale", GradVarName("Out")},
{"data_layout",
"out_d",
"out_h",
"out_w",
"scale",
"interp_method",
"align_corners",
"align_mode"},
{GradVarName("X")});
return KernelSignature("nearest_interp_v2_grad",
{"X", "OutSize", "SizeTensor", "Scale", "Out@GRAD"},
{"data_layout",
"out_d",
"out_h",
"out_w",
"scale",
"interp_method",
"align_corners",
"align_mode"},
{"X@GRAD"});
}
KernelSignature TrilinearInterpGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature(
"trilinear_interp_v2_grad",
{"X", "OutSize", "SizeTensor", "Scale", GradVarName("Out")},
{"data_layout",
"out_d",
"out_h",
"out_w",
"scale",
"interp_method",
"align_corners",
"align_mode"},
{GradVarName("X")});
return KernelSignature("trilinear_interp_v2_grad",
{"X", "OutSize", "SizeTensor", "Scale", "Out@GRAD"},
{"data_layout",
"out_d",
"out_h",
"out_w",
"scale",
"interp_method",
"align_corners",
"align_mode"},
{"X@GRAD"});
}
KernelSignature LinearInterpGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature(
"linear_interp_v2_grad",
{"X", "OutSize", "SizeTensor", "Scale", GradVarName("Out")},
{"data_layout",
"out_d",
"out_h",
"out_w",
"scale",
"interp_method",
"align_corners",
"align_mode"},
{GradVarName("X")});
return KernelSignature("linear_interp_v2_grad",
{"X", "OutSize", "SizeTensor", "Scale", "Out@GRAD"},
{"data_layout",
"out_d",
"out_h",
"out_w",
"scale",
"interp_method",
"align_corners",
"align_mode"},
{"X@GRAD"});
}
KernelSignature BicubicInterpGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature(
"bicubic_interp_v2_grad",
{"X", "OutSize", "SizeTensor", "Scale", GradVarName("Out")},
{"data_layout",
"out_d",
"out_h",
"out_w",
"scale",
"interp_method",
"align_corners",
"align_mode"},
{GradVarName("X")});
return KernelSignature("bicubic_interp_v2_grad",
{"X", "OutSize", "SizeTensor", "Scale", "Out@GRAD"},
{"data_layout",
"out_d",
"out_h",
"out_w",
"scale",
"interp_method",
"align_corners",
"align_mode"},
{"X@GRAD"});
}
} // namespace phi
......
......@@ -20,9 +20,9 @@ namespace phi {
KernelSignature KLDivLossGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("kldiv_loss_grad",
{"X", "Target", GradVarName("Loss")},
{"X", "Target", "Loss@GRAD"},
{"reduction"},
{GradVarName("X")});
{"X@GRAD"});
}
} // namespace phi
......
......@@ -17,10 +17,8 @@
namespace phi {
KernelSignature KronGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("kron_grad",
{"X", "Y", GradVarName("Out")},
{},
{GradVarName("X"), GradVarName("Y")});
return KernelSignature(
"kron_grad", {"X", "Y", "Out@GRAD"}, {}, {"X@GRAD", "Y@GRAD"});
}
} // namespace phi
......
......@@ -20,9 +20,9 @@ namespace phi {
KernelSignature KthvalueGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("kthvalue_grad",
{"X", "Indices", GradVarName("Out")},
{"X", "Indices", "Out@GRAD"},
{"k", "axis", "keepdim"},
{GradVarName("X")});
{"X@GRAD"});
}
} // namespace phi
......
......@@ -24,10 +24,8 @@ KernelSignature LabelSmoothOpArgumentMapping(
KernelSignature LabelSmoothGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("label_smooth_grad",
{GradVarName("Out")},
{"epsilon"},
{GradVarName("X")});
return KernelSignature(
"label_smooth_grad", {"Out@GRAD"}, {"epsilon"}, {"X@GRAD"});
}
} // namespace phi
......
......@@ -25,11 +25,10 @@ KernelSignature LayerNormOpArgumentMapping(const ArgumentMappingContext& ctx) {
KernelSignature LayerNormGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature(
"layer_norm_grad",
{"X", "Mean", "Variance", "Scale", "Bias", GradVarName("Y")},
{"epsilon", "begin_norm_axis", "is_test"},
{GradVarName("X"), GradVarName("Scale"), GradVarName("Bias")});
return KernelSignature("layer_norm_grad",
{"X", "Mean", "Variance", "Scale", "Bias", "Y@GRAD"},
{"epsilon", "begin_norm_axis", "is_test"},
{"X@GRAD", "Scale@GRAD", "Bias@GRAD"});
}
} // namespace phi
......
......@@ -22,9 +22,9 @@ KernelSignature LerpOpArgumentMapping(const ArgumentMappingContext& ctx) {
KernelSignature LerpGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("lerp_grad",
{"X", "Y", "Weight", "Out", GradVarName("Out")},
{"X", "Y", "Weight", "Out", "Out@GRAD"},
{},
{GradVarName("X"), GradVarName("Y")});
{"X@GRAD", "Y@GRAD"});
}
} // namespace phi
......
......@@ -17,8 +17,7 @@
namespace phi {
KernelSignature LgammaGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature(
"lgamma_grad", {"X", GradVarName("Out")}, {}, {GradVarName("X")});
return KernelSignature("lgamma_grad", {"X", "Out@GRAD"}, {}, {"X@GRAD"});
}
} // namespace phi
......
......@@ -19,9 +19,9 @@ namespace phi {
KernelSignature LogLossGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("log_loss_grad",
{"Predicted", "Labels", GradVarName("Loss")},
{"Predicted", "Labels", "Loss@GRAD"},
{"epsilon"},
{GradVarName("Predicted")});
{"Predicted@GRAD"});
}
} // namespace phi
......
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册