未验证 提交 ddb1e23f 编写于 作者: C Chen Weihang 提交者: GitHub

[PTen] Add HasAttr for ArgumentMappingContext (#39464)

* add has_attr for arg map context

* skip useless attr now

* skip attr if not exists

* fix typo
上级 e07420b9
......@@ -41,6 +41,10 @@ class InferShapeArgumentMappingContext : public pten::ArgumentMappingContext {
return ctx_.HasOutput(name);
}
bool HasAttr(const std::string& name) const override {
return ctx_.HasAttr(name);
}
paddle::any Attr(const std::string& name) const override {
auto& attr = ctx_.Attrs().GetAttr(name);
return GetAttrValue(attr);
......@@ -278,21 +282,47 @@ pten::InferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
pten::InferMetaContext infer_meta_context(ctx->IsRuntime());
auto& input_names = std::get<0>(signature.args);
auto& attr_names = std::get<1>(signature.args);
auto& output_names = std::get<2>(signature.args);
// TODO(chenweihang): support attrs in next pr
// auto& attr_names = std::get<1>(signature.args);
// TODO(chenweihang): support multiple inputs and outputs
// TODO(chenweihang): support multiple inputs and outputs later
pten::InferMetaContext infer_mete_context;
for (auto& in_name : input_names) {
infer_meta_context.EmplaceBackInput(std::make_shared<CompatMetaTensor>(
ctx->GetInputVarPtrs(in_name)[0], ctx->IsRuntime()));
if (ctx->HasInput(in_name)) {
infer_meta_context.EmplaceBackInput(std::make_shared<CompatMetaTensor>(
ctx->GetInputVarPtrs(in_name)[0], ctx->IsRuntime()));
} else {
infer_meta_context.EmplaceBackInput({nullptr});
}
}
auto attr_reader = ctx->Attrs();
for (auto& attr_name : attr_names) {
if (ctx->HasAttr(attr_name)) {
auto& attr = attr_reader.GetAttr(attr_name);
if (std::type_index(attr.type()) == std::type_index(typeid(bool))) {
infer_meta_context.EmplaceBackAttr(BOOST_GET_CONST(bool, attr));
} else if (std::type_index(attr.type()) ==
std::type_index(typeid(float))) {
infer_meta_context.EmplaceBackAttr(BOOST_GET_CONST(float, attr));
} else {
// do nothing, skip useless attrs now
// TODO(chenweihang): support other attr type later and throw error
// if attr is cannot parsed
}
} else {
// do nothing
}
}
for (auto& out_name : output_names) {
infer_meta_context.EmplaceBackOutput(std::make_shared<CompatMetaTensor>(
ctx->GetOutputVarPtrs(out_name)[0], ctx->IsRuntime()));
if (ctx->HasOutput(out_name)) {
infer_meta_context.EmplaceBackOutput(std::make_shared<CompatMetaTensor>(
ctx->GetOutputVarPtrs(out_name)[0], ctx->IsRuntime()));
} else {
infer_meta_context.EmplaceBackOutput({nullptr});
}
}
// TODO(chenweihang): support attrs later
return infer_meta_context;
}
......
......@@ -74,6 +74,10 @@ bool InterpretercoreInferShapeContext::HasOutput(
return out[0] != nullptr;
}
bool InterpretercoreInferShapeContext::HasAttr(const std::string& name) const {
return op_.HasAttr(name);
}
bool InterpretercoreInferShapeContext::HasInputs(
const std::string& name) const {
const auto& ins = ctx_.inputs;
......
......@@ -54,6 +54,8 @@ class InterpretercoreInferShapeContext : public InferShapeContext {
bool HasOutput(const std::string& name) const override;
bool HasAttr(const std::string& name) const override;
bool HasInputs(const std::string& name) const override;
bool HasOutputs(const std::string& name) const override;
......
......@@ -35,6 +35,8 @@ class CompileTimeInferShapeContext : public InferShapeContext {
bool HasOutput(const std::string &name) const override;
bool HasAttr(const std::string &name) const override;
bool HasInputs(const std::string &name) const override;
bool HasOutputs(const std::string &name) const override;
......@@ -855,6 +857,10 @@ bool CompileTimeInferShapeContext::HasOutput(const std::string &name) const {
return block_.HasVarRecursive(output_names[0]);
}
bool CompileTimeInferShapeContext::HasAttr(const std::string &name) const {
return op_.HasAttr(name);
}
bool CompileTimeInferShapeContext::HasInputs(const std::string &name) const {
if (op_.Inputs().find(name) == op_.Inputs().end()) {
return false;
......
......@@ -664,6 +664,10 @@ class RuntimeInferShapeContext : public InferShapeContext {
return out[0] != nullptr;
}
bool HasAttr(const std::string& name) const override {
return op_.HasAttr(name);
}
bool HasInputs(const std::string& name) const override {
const auto& ins = ctx_.inputs;
auto it = ins.find(name);
......
......@@ -455,6 +455,10 @@ class ExecutionArgumentMappingContext : public pten::ArgumentMappingContext {
return ctx_.HasOutput(name);
}
bool HasAttr(const std::string& name) const override {
return ctx_.HasAttr(name);
}
paddle::any Attr(const std::string& name) const override {
auto& attr = ctx_.GetAttr(name);
return GetAttrValue(attr);
......
......@@ -61,6 +61,7 @@ class InferShapeContext {
virtual ~InferShapeContext() = default;
virtual bool HasInput(const std::string &name) const = 0;
virtual bool HasOutput(const std::string &name) const = 0;
virtual bool HasAttr(const std::string &name) const = 0;
virtual std::vector<proto::VarType::Type> GetInputsVarType(
const std::string &name) const = 0;
......
......@@ -78,6 +78,10 @@ class DygraphInferShapeContext : public framework::InferShapeContext {
return out[0] != nullptr;
}
bool HasAttr(const std::string& name) const override {
return attrs_->count(name) > 0 || default_attrs_->count(name) > 0;
}
bool HasInputs(const std::string& name) const override {
auto it = var_map_in_->find(name);
if (it == var_map_in_->end() || it->second.empty()) {
......
......@@ -16,6 +16,10 @@
#include <string>
#include <vector>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/pten/core/infermeta_utils.h"
#include "paddle/pten/infermeta/backward.h"
namespace paddle {
namespace operators {
......@@ -343,25 +347,6 @@ class MatMulV2OpGrad : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext* context) const override {
OP_INOUT_CHECK(context->HasInput("X"), "Input", "X", "matmul_v2");
OP_INOUT_CHECK(context->HasInput("Y"), "Input", "Y", "matmul_v2");
OP_INOUT_CHECK(context->HasInput(framework::GradVarName("Out")), "Input",
"Out@GRAD", "matmul_v2");
auto x_dims = context->GetInputDim("X");
auto y_dims = context->GetInputDim("Y");
auto x_grad_name = framework::GradVarName("X");
auto y_grad_name = framework::GradVarName("Y");
if (context->HasOutput(x_grad_name)) {
context->SetOutputDim(x_grad_name, x_dims);
}
if (context->HasOutput(y_grad_name)) {
context->SetOutputDim(y_grad_name, y_dims);
}
}
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
auto input_data_type = OperatorWithKernel::IndicateVarDataType(
......@@ -539,9 +524,12 @@ REGISTER_OPERATOR(matmul_v2, ops::MatMulV2Op, ops::MatMulV2OpMaker,
ops::MatMulV2GradOpMaker<paddle::framework::OpDesc>,
ops::MatMulV2GradOpMaker<paddle::imperative::OpBase>);
DELCARE_INFER_SHAPE_FUNCTOR(matmul_v2_grad, MatMulV2GradInferShapeFunctor,
PT_INFER_META(pten::MatmulGradInferMeta));
REGISTER_OPERATOR(matmul_v2_grad, ops::MatMulV2OpGrad,
ops::MatMulV2OpDoubleGradMaker<paddle::framework::OpDesc>,
ops::MatMulV2OpDoubleGradMaker<paddle::imperative::OpBase>);
ops::MatMulV2OpDoubleGradMaker<paddle::imperative::OpBase>,
MatMulV2GradInferShapeFunctor);
REGISTER_OPERATOR(matmul_v2_grad_grad, ops::MatMulV2OpDoubleGrad,
ops::MatMulV2OpTripleGradMaker<paddle::framework::OpDesc>,
......
......@@ -77,6 +77,7 @@ class ArgumentMappingContext {
virtual bool HasInput(const std::string& name) const = 0;
virtual bool HasOutput(const std::string& name) const = 0;
virtual bool HasAttr(const std::string& name) const = 0;
// now we can't use Attribute here, it will cause pten relay on
// boost::variant and BlockDesc
......
......@@ -146,6 +146,7 @@ struct InferMetaFnImpl<Return (*)(Args...), infer_meta_fn> {
}
};
// TODO(chenweihang): support other attr type later
PT_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(bool);
PT_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(int);
PT_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(int64_t);
......
......@@ -23,8 +23,12 @@ void MatmulGradInferMeta(const MetaTensor& x,
bool transpose_y,
MetaTensor* dx,
MetaTensor* dy) {
dx->share_meta(x);
dy->share_meta(y);
if (dx) {
dx->share_meta(x);
}
if (dy) {
dy->share_meta(y);
}
}
} // namespace pten
......@@ -17,10 +17,17 @@ limitations under the License. */
namespace pten {
KernelSignature MatmulGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("matmul_grad",
{"X", "Y", GradVarName("Out")},
{"trans_x", "trans_y"},
{GradVarName("X"), GradVarName("Y")});
if (ctx.HasAttr("use_addto")) {
return KernelSignature("addto_matmul_grad",
{"X", "Y", GradVarName("Out")},
{"trans_x", "trans_y", "use_addto"},
{GradVarName("X"), GradVarName("Y")});
} else {
return KernelSignature("matmul_grad",
{"X", "Y", GradVarName("Out")},
{"trans_x", "trans_y"},
{GradVarName("X"), GradVarName("Y")});
}
}
KernelSignature MatmulDoubleGradOpArgumentMapping(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册