未验证 提交 ddb1e23f 编写于 作者: C Chen Weihang 提交者: GitHub

[PTen] Add HasAttr for ArgumentMappingContext (#39464)

* add has_attr for arg map context

* skip useless attr now

* skip attr if not exists

* fix typo
上级 e07420b9
...@@ -41,6 +41,10 @@ class InferShapeArgumentMappingContext : public pten::ArgumentMappingContext { ...@@ -41,6 +41,10 @@ class InferShapeArgumentMappingContext : public pten::ArgumentMappingContext {
return ctx_.HasOutput(name); return ctx_.HasOutput(name);
} }
bool HasAttr(const std::string& name) const override {
return ctx_.HasAttr(name);
}
paddle::any Attr(const std::string& name) const override { paddle::any Attr(const std::string& name) const override {
auto& attr = ctx_.Attrs().GetAttr(name); auto& attr = ctx_.Attrs().GetAttr(name);
return GetAttrValue(attr); return GetAttrValue(attr);
...@@ -278,21 +282,47 @@ pten::InferMetaContext BuildInferMetaContext(InferShapeContext* ctx, ...@@ -278,21 +282,47 @@ pten::InferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
pten::InferMetaContext infer_meta_context(ctx->IsRuntime()); pten::InferMetaContext infer_meta_context(ctx->IsRuntime());
auto& input_names = std::get<0>(signature.args); auto& input_names = std::get<0>(signature.args);
auto& attr_names = std::get<1>(signature.args);
auto& output_names = std::get<2>(signature.args); auto& output_names = std::get<2>(signature.args);
// TODO(chenweihang): support attrs in next pr
// auto& attr_names = std::get<1>(signature.args);
// TODO(chenweihang): support multiple inputs and outputs // TODO(chenweihang): support multiple inputs and outputs later
pten::InferMetaContext infer_mete_context; pten::InferMetaContext infer_mete_context;
for (auto& in_name : input_names) { for (auto& in_name : input_names) {
if (ctx->HasInput(in_name)) {
infer_meta_context.EmplaceBackInput(std::make_shared<CompatMetaTensor>( infer_meta_context.EmplaceBackInput(std::make_shared<CompatMetaTensor>(
ctx->GetInputVarPtrs(in_name)[0], ctx->IsRuntime())); ctx->GetInputVarPtrs(in_name)[0], ctx->IsRuntime()));
} else {
infer_meta_context.EmplaceBackInput({nullptr});
}
}
auto attr_reader = ctx->Attrs();
for (auto& attr_name : attr_names) {
if (ctx->HasAttr(attr_name)) {
auto& attr = attr_reader.GetAttr(attr_name);
if (std::type_index(attr.type()) == std::type_index(typeid(bool))) {
infer_meta_context.EmplaceBackAttr(BOOST_GET_CONST(bool, attr));
} else if (std::type_index(attr.type()) ==
std::type_index(typeid(float))) {
infer_meta_context.EmplaceBackAttr(BOOST_GET_CONST(float, attr));
} else {
// do nothing, skip useless attrs now
// TODO(chenweihang): support other attr type later and throw error
// if attr is cannot parsed
} }
} else {
// do nothing
}
}
for (auto& out_name : output_names) { for (auto& out_name : output_names) {
if (ctx->HasOutput(out_name)) {
infer_meta_context.EmplaceBackOutput(std::make_shared<CompatMetaTensor>( infer_meta_context.EmplaceBackOutput(std::make_shared<CompatMetaTensor>(
ctx->GetOutputVarPtrs(out_name)[0], ctx->IsRuntime())); ctx->GetOutputVarPtrs(out_name)[0], ctx->IsRuntime()));
} else {
infer_meta_context.EmplaceBackOutput({nullptr});
}
} }
// TODO(chenweihang): support attrs later
return infer_meta_context; return infer_meta_context;
} }
......
...@@ -74,6 +74,10 @@ bool InterpretercoreInferShapeContext::HasOutput( ...@@ -74,6 +74,10 @@ bool InterpretercoreInferShapeContext::HasOutput(
return out[0] != nullptr; return out[0] != nullptr;
} }
bool InterpretercoreInferShapeContext::HasAttr(const std::string& name) const {
return op_.HasAttr(name);
}
bool InterpretercoreInferShapeContext::HasInputs( bool InterpretercoreInferShapeContext::HasInputs(
const std::string& name) const { const std::string& name) const {
const auto& ins = ctx_.inputs; const auto& ins = ctx_.inputs;
......
...@@ -54,6 +54,8 @@ class InterpretercoreInferShapeContext : public InferShapeContext { ...@@ -54,6 +54,8 @@ class InterpretercoreInferShapeContext : public InferShapeContext {
bool HasOutput(const std::string& name) const override; bool HasOutput(const std::string& name) const override;
bool HasAttr(const std::string& name) const override;
bool HasInputs(const std::string& name) const override; bool HasInputs(const std::string& name) const override;
bool HasOutputs(const std::string& name) const override; bool HasOutputs(const std::string& name) const override;
......
...@@ -35,6 +35,8 @@ class CompileTimeInferShapeContext : public InferShapeContext { ...@@ -35,6 +35,8 @@ class CompileTimeInferShapeContext : public InferShapeContext {
bool HasOutput(const std::string &name) const override; bool HasOutput(const std::string &name) const override;
bool HasAttr(const std::string &name) const override;
bool HasInputs(const std::string &name) const override; bool HasInputs(const std::string &name) const override;
bool HasOutputs(const std::string &name) const override; bool HasOutputs(const std::string &name) const override;
...@@ -855,6 +857,10 @@ bool CompileTimeInferShapeContext::HasOutput(const std::string &name) const { ...@@ -855,6 +857,10 @@ bool CompileTimeInferShapeContext::HasOutput(const std::string &name) const {
return block_.HasVarRecursive(output_names[0]); return block_.HasVarRecursive(output_names[0]);
} }
bool CompileTimeInferShapeContext::HasAttr(const std::string &name) const {
return op_.HasAttr(name);
}
bool CompileTimeInferShapeContext::HasInputs(const std::string &name) const { bool CompileTimeInferShapeContext::HasInputs(const std::string &name) const {
if (op_.Inputs().find(name) == op_.Inputs().end()) { if (op_.Inputs().find(name) == op_.Inputs().end()) {
return false; return false;
......
...@@ -664,6 +664,10 @@ class RuntimeInferShapeContext : public InferShapeContext { ...@@ -664,6 +664,10 @@ class RuntimeInferShapeContext : public InferShapeContext {
return out[0] != nullptr; return out[0] != nullptr;
} }
bool HasAttr(const std::string& name) const override {
return op_.HasAttr(name);
}
bool HasInputs(const std::string& name) const override { bool HasInputs(const std::string& name) const override {
const auto& ins = ctx_.inputs; const auto& ins = ctx_.inputs;
auto it = ins.find(name); auto it = ins.find(name);
......
...@@ -455,6 +455,10 @@ class ExecutionArgumentMappingContext : public pten::ArgumentMappingContext { ...@@ -455,6 +455,10 @@ class ExecutionArgumentMappingContext : public pten::ArgumentMappingContext {
return ctx_.HasOutput(name); return ctx_.HasOutput(name);
} }
bool HasAttr(const std::string& name) const override {
return ctx_.HasAttr(name);
}
paddle::any Attr(const std::string& name) const override { paddle::any Attr(const std::string& name) const override {
auto& attr = ctx_.GetAttr(name); auto& attr = ctx_.GetAttr(name);
return GetAttrValue(attr); return GetAttrValue(attr);
......
...@@ -61,6 +61,7 @@ class InferShapeContext { ...@@ -61,6 +61,7 @@ class InferShapeContext {
virtual ~InferShapeContext() = default; virtual ~InferShapeContext() = default;
virtual bool HasInput(const std::string &name) const = 0; virtual bool HasInput(const std::string &name) const = 0;
virtual bool HasOutput(const std::string &name) const = 0; virtual bool HasOutput(const std::string &name) const = 0;
virtual bool HasAttr(const std::string &name) const = 0;
virtual std::vector<proto::VarType::Type> GetInputsVarType( virtual std::vector<proto::VarType::Type> GetInputsVarType(
const std::string &name) const = 0; const std::string &name) const = 0;
......
...@@ -78,6 +78,10 @@ class DygraphInferShapeContext : public framework::InferShapeContext { ...@@ -78,6 +78,10 @@ class DygraphInferShapeContext : public framework::InferShapeContext {
return out[0] != nullptr; return out[0] != nullptr;
} }
bool HasAttr(const std::string& name) const override {
return attrs_->count(name) > 0 || default_attrs_->count(name) > 0;
}
bool HasInputs(const std::string& name) const override { bool HasInputs(const std::string& name) const override {
auto it = var_map_in_->find(name); auto it = var_map_in_->find(name);
if (it == var_map_in_->end() || it->second.empty()) { if (it == var_map_in_->end() || it->second.empty()) {
......
...@@ -16,6 +16,10 @@ ...@@ -16,6 +16,10 @@
#include <string> #include <string>
#include <vector> #include <vector>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/pten/core/infermeta_utils.h"
#include "paddle/pten/infermeta/backward.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -343,25 +347,6 @@ class MatMulV2OpGrad : public framework::OperatorWithKernel { ...@@ -343,25 +347,6 @@ class MatMulV2OpGrad : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
protected: protected:
void InferShape(framework::InferShapeContext* context) const override {
OP_INOUT_CHECK(context->HasInput("X"), "Input", "X", "matmul_v2");
OP_INOUT_CHECK(context->HasInput("Y"), "Input", "Y", "matmul_v2");
OP_INOUT_CHECK(context->HasInput(framework::GradVarName("Out")), "Input",
"Out@GRAD", "matmul_v2");
auto x_dims = context->GetInputDim("X");
auto y_dims = context->GetInputDim("Y");
auto x_grad_name = framework::GradVarName("X");
auto y_grad_name = framework::GradVarName("Y");
if (context->HasOutput(x_grad_name)) {
context->SetOutputDim(x_grad_name, x_dims);
}
if (context->HasOutput(y_grad_name)) {
context->SetOutputDim(y_grad_name, y_dims);
}
}
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
auto input_data_type = OperatorWithKernel::IndicateVarDataType( auto input_data_type = OperatorWithKernel::IndicateVarDataType(
...@@ -539,9 +524,12 @@ REGISTER_OPERATOR(matmul_v2, ops::MatMulV2Op, ops::MatMulV2OpMaker, ...@@ -539,9 +524,12 @@ REGISTER_OPERATOR(matmul_v2, ops::MatMulV2Op, ops::MatMulV2OpMaker,
ops::MatMulV2GradOpMaker<paddle::framework::OpDesc>, ops::MatMulV2GradOpMaker<paddle::framework::OpDesc>,
ops::MatMulV2GradOpMaker<paddle::imperative::OpBase>); ops::MatMulV2GradOpMaker<paddle::imperative::OpBase>);
DELCARE_INFER_SHAPE_FUNCTOR(matmul_v2_grad, MatMulV2GradInferShapeFunctor,
PT_INFER_META(pten::MatmulGradInferMeta));
REGISTER_OPERATOR(matmul_v2_grad, ops::MatMulV2OpGrad, REGISTER_OPERATOR(matmul_v2_grad, ops::MatMulV2OpGrad,
ops::MatMulV2OpDoubleGradMaker<paddle::framework::OpDesc>, ops::MatMulV2OpDoubleGradMaker<paddle::framework::OpDesc>,
ops::MatMulV2OpDoubleGradMaker<paddle::imperative::OpBase>); ops::MatMulV2OpDoubleGradMaker<paddle::imperative::OpBase>,
MatMulV2GradInferShapeFunctor);
REGISTER_OPERATOR(matmul_v2_grad_grad, ops::MatMulV2OpDoubleGrad, REGISTER_OPERATOR(matmul_v2_grad_grad, ops::MatMulV2OpDoubleGrad,
ops::MatMulV2OpTripleGradMaker<paddle::framework::OpDesc>, ops::MatMulV2OpTripleGradMaker<paddle::framework::OpDesc>,
......
...@@ -77,6 +77,7 @@ class ArgumentMappingContext { ...@@ -77,6 +77,7 @@ class ArgumentMappingContext {
virtual bool HasInput(const std::string& name) const = 0; virtual bool HasInput(const std::string& name) const = 0;
virtual bool HasOutput(const std::string& name) const = 0; virtual bool HasOutput(const std::string& name) const = 0;
virtual bool HasAttr(const std::string& name) const = 0;
// now we can't use Attribute here, it will cause pten relay on // now we can't use Attribute here, it will cause pten relay on
// boost::variant and BlockDesc // boost::variant and BlockDesc
......
...@@ -146,6 +146,7 @@ struct InferMetaFnImpl<Return (*)(Args...), infer_meta_fn> { ...@@ -146,6 +146,7 @@ struct InferMetaFnImpl<Return (*)(Args...), infer_meta_fn> {
} }
}; };
// TODO(chenweihang): support other attr type later
PT_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(bool); PT_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(bool);
PT_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(int); PT_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(int);
PT_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(int64_t); PT_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(int64_t);
......
...@@ -23,8 +23,12 @@ void MatmulGradInferMeta(const MetaTensor& x, ...@@ -23,8 +23,12 @@ void MatmulGradInferMeta(const MetaTensor& x,
bool transpose_y, bool transpose_y,
MetaTensor* dx, MetaTensor* dx,
MetaTensor* dy) { MetaTensor* dy) {
if (dx) {
dx->share_meta(x); dx->share_meta(x);
}
if (dy) {
dy->share_meta(y); dy->share_meta(y);
}
} }
} // namespace pten } // namespace pten
...@@ -17,10 +17,17 @@ limitations under the License. */ ...@@ -17,10 +17,17 @@ limitations under the License. */
namespace pten { namespace pten {
KernelSignature MatmulGradOpArgumentMapping(const ArgumentMappingContext& ctx) { KernelSignature MatmulGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
if (ctx.HasAttr("use_addto")) {
return KernelSignature("addto_matmul_grad",
{"X", "Y", GradVarName("Out")},
{"trans_x", "trans_y", "use_addto"},
{GradVarName("X"), GradVarName("Y")});
} else {
return KernelSignature("matmul_grad", return KernelSignature("matmul_grad",
{"X", "Y", GradVarName("Out")}, {"X", "Y", GradVarName("Out")},
{"trans_x", "trans_y"}, {"trans_x", "trans_y"},
{GradVarName("X"), GradVarName("Y")}); {GradVarName("X"), GradVarName("Y")});
}
} }
KernelSignature MatmulDoubleGradOpArgumentMapping( KernelSignature MatmulDoubleGradOpArgumentMapping(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册