diff --git a/paddle/fluid/framework/infershape_utils.cc b/paddle/fluid/framework/infershape_utils.cc
index 652286ab2666e6253173f6b7d5c3751a22ee788c..efe2423b0e8a5abdcccf2db823023e1037f6ec8b 100644
--- a/paddle/fluid/framework/infershape_utils.cc
+++ b/paddle/fluid/framework/infershape_utils.cc
@@ -41,6 +41,10 @@ class InferShapeArgumentMappingContext : public pten::ArgumentMappingContext {
     return ctx_.HasOutput(name);
   }
 
+  bool HasAttr(const std::string& name) const override {
+    return ctx_.HasAttr(name);
+  }
+
   paddle::any Attr(const std::string& name) const override {
     auto& attr = ctx_.Attrs().GetAttr(name);
     return GetAttrValue(attr);
@@ -278,21 +282,47 @@ pten::InferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
   pten::InferMetaContext infer_meta_context(ctx->IsRuntime());
 
   auto& input_names = std::get<0>(signature.args);
+  auto& attr_names = std::get<1>(signature.args);
   auto& output_names = std::get<2>(signature.args);
-  // TODO(chenweihang): support attrs in next pr
-  // auto& attr_names = std::get<1>(signature.args);
-  // TODO(chenweihang): support multiple inputs and outputs
+  // TODO(chenweihang): support multiple inputs and outputs later
   pten::InferMetaContext infer_mete_context;
   for (auto& in_name : input_names) {
-    infer_meta_context.EmplaceBackInput(std::make_shared<CompatMetaTensor>(
-        ctx->GetInputVarPtrs(in_name)[0], ctx->IsRuntime()));
+    if (ctx->HasInput(in_name)) {
+      infer_meta_context.EmplaceBackInput(std::make_shared<CompatMetaTensor>(
+          ctx->GetInputVarPtrs(in_name)[0], ctx->IsRuntime()));
+    } else {
+      infer_meta_context.EmplaceBackInput({nullptr});
+    }
   }
+
+  auto attr_reader = ctx->Attrs();
+  for (auto& attr_name : attr_names) {
+    if (ctx->HasAttr(attr_name)) {
+      auto& attr = attr_reader.GetAttr(attr_name);
+      if (std::type_index(attr.type()) == std::type_index(typeid(bool))) {
+        infer_meta_context.EmplaceBackAttr(BOOST_GET_CONST(bool, attr));
+      } else if (std::type_index(attr.type()) ==
+                 std::type_index(typeid(float))) {
+        infer_meta_context.EmplaceBackAttr(BOOST_GET_CONST(float, attr));
+      } else {
+        // do nothing, skip unused attrs for now
+        // TODO(chenweihang): support other attr types later and throw an
+        // error if an attr cannot be parsed
+      }
+    } else {
+      // do nothing
+    }
+  }
+
   for (auto& out_name : output_names) {
-    infer_meta_context.EmplaceBackOutput(std::make_shared<CompatMetaTensor>(
-        ctx->GetOutputVarPtrs(out_name)[0], ctx->IsRuntime()));
+    if (ctx->HasOutput(out_name)) {
+      infer_meta_context.EmplaceBackOutput(std::make_shared<CompatMetaTensor>(
+          ctx->GetOutputVarPtrs(out_name)[0], ctx->IsRuntime()));
+    } else {
+      infer_meta_context.EmplaceBackOutput({nullptr});
+    }
   }
-  // TODO(chenweihang): support attrs later
 
   return infer_meta_context;
 }
diff --git a/paddle/fluid/framework/new_executor/new_executor_defs.cc b/paddle/fluid/framework/new_executor/new_executor_defs.cc
index c72cbda008f3baf24c712ca3b35e68fb25f0ea06..67d60975c95d9a39676b9bb7cd7ee482c9fab3cd 100644
--- a/paddle/fluid/framework/new_executor/new_executor_defs.cc
+++ b/paddle/fluid/framework/new_executor/new_executor_defs.cc
@@ -74,6 +74,10 @@ bool InterpretercoreInferShapeContext::HasOutput(
   return out[0] != nullptr;
 }
 
+bool InterpretercoreInferShapeContext::HasAttr(const std::string& name) const {
+  return op_.HasAttr(name);
+}
+
 bool InterpretercoreInferShapeContext::HasInputs(
     const std::string& name) const {
   const auto& ins = ctx_.inputs;
diff --git a/paddle/fluid/framework/new_executor/new_executor_defs.h b/paddle/fluid/framework/new_executor/new_executor_defs.h
index b61b8af1e4a1b38f3db686e3b438aaf7745ed3c0..e00b1daf28a9ff469bfd0b81ca620161844f94b4 100644
--- a/paddle/fluid/framework/new_executor/new_executor_defs.h
+++ b/paddle/fluid/framework/new_executor/new_executor_defs.h
@@ -54,6 +54,8 @@ class InterpretercoreInferShapeContext : public InferShapeContext {
 
   bool HasOutput(const std::string& name) const override;
 
+  bool HasAttr(const std::string& name) const override;
+
   bool HasInputs(const std::string& name) const override;
 
   bool HasOutputs(const std::string& name) const override;
diff --git a/paddle/fluid/framework/op_desc.cc b/paddle/fluid/framework/op_desc.cc
index 7bceeb05bac599ef7150435eff5ed67b4076e846..942beb6e9a885283293b1823cc4e1b89bf1905d0 100644
--- a/paddle/fluid/framework/op_desc.cc
+++ b/paddle/fluid/framework/op_desc.cc
@@ -35,6 +35,8 @@ class CompileTimeInferShapeContext : public InferShapeContext {
 
   bool HasOutput(const std::string &name) const override;
 
+  bool HasAttr(const std::string &name) const override;
+
   bool HasInputs(const std::string &name) const override;
 
   bool HasOutputs(const std::string &name) const override;
@@ -855,6 +857,10 @@ bool CompileTimeInferShapeContext::HasOutput(const std::string &name) const {
   return block_.HasVarRecursive(output_names[0]);
 }
 
+bool CompileTimeInferShapeContext::HasAttr(const std::string &name) const {
+  return op_.HasAttr(name);
+}
+
 bool CompileTimeInferShapeContext::HasInputs(const std::string &name) const {
   if (op_.Inputs().find(name) == op_.Inputs().end()) {
     return false;
diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc
index 392047c150dc125b3d90a4ede323e3c45e2525ce..6af5a70d66085fe226ae0d6e73fcce0276a46008 100644
--- a/paddle/fluid/framework/operator.cc
+++ b/paddle/fluid/framework/operator.cc
@@ -664,6 +664,10 @@ class RuntimeInferShapeContext : public InferShapeContext {
     return out[0] != nullptr;
   }
 
+  bool HasAttr(const std::string& name) const override {
+    return op_.HasAttr(name);
+  }
+
   bool HasInputs(const std::string& name) const override {
     const auto& ins = ctx_.inputs;
     auto it = ins.find(name);
diff --git a/paddle/fluid/framework/operator.h b/paddle/fluid/framework/operator.h
index 2294d67fbf2f32f59b89b1164823c07c7e08bd39..db529bd17f4ab8c40b555c0be83fa6cd3daa1f14 100644
--- a/paddle/fluid/framework/operator.h
+++ b/paddle/fluid/framework/operator.h
@@ -455,6 +455,10 @@ class ExecutionArgumentMappingContext : public pten::ArgumentMappingContext {
     return ctx_.HasOutput(name);
   }
 
+  bool HasAttr(const std::string& name) const override {
+    return ctx_.HasAttr(name);
+  }
+
   paddle::any Attr(const std::string& name) const override {
     auto& attr = ctx_.GetAttr(name);
     return GetAttrValue(attr);
diff --git a/paddle/fluid/framework/shape_inference.h b/paddle/fluid/framework/shape_inference.h
index 791600b39c3d94a911f6eae9113dc703392a7e55..09568168d8526a8352700e7edb2b2bae181eba20 100644
--- a/paddle/fluid/framework/shape_inference.h
+++ b/paddle/fluid/framework/shape_inference.h
@@ -61,6 +61,7 @@ class InferShapeContext {
   virtual ~InferShapeContext() = default;
   virtual bool HasInput(const std::string &name) const = 0;
   virtual bool HasOutput(const std::string &name) const = 0;
+  virtual bool HasAttr(const std::string &name) const = 0;
 
   virtual std::vector<proto::VarType::Type> GetInputsVarType(
       const std::string &name) const = 0;
diff --git a/paddle/fluid/imperative/infer_shape_context.h b/paddle/fluid/imperative/infer_shape_context.h
index 7033b9c11712dcefd49d42894fec6283eb064c9f..554657c71387b1713ec1d70526e723d2bf7a3cac 100644
--- a/paddle/fluid/imperative/infer_shape_context.h
+++ b/paddle/fluid/imperative/infer_shape_context.h
@@ -78,6 +78,10 @@ class DygraphInferShapeContext : public framework::InferShapeContext {
     return out[0] != nullptr;
   }
 
+  bool HasAttr(const std::string& name) const override {
+    return attrs_->count(name) > 0 || default_attrs_->count(name) > 0;
+  }
+
   bool HasInputs(const std::string& name) const override {
     auto it = var_map_in_->find(name);
     if (it == var_map_in_->end() || it->second.empty()) {
diff --git a/paddle/fluid/operators/matmul_v2_op.cc b/paddle/fluid/operators/matmul_v2_op.cc
index 5add86f5b3c74ac82ea4f4fbb0c8c1c9cb0d00f6..f7a5e2a8af409500ea7d51dc83ef32c5aad9142a 100644
--- a/paddle/fluid/operators/matmul_v2_op.cc
+++ b/paddle/fluid/operators/matmul_v2_op.cc
@@ -16,6 +16,10 @@
 #include <string>
 #include <vector>
 
+#include "paddle/fluid/framework/infershape_utils.h"
+#include "paddle/pten/core/infermeta_utils.h"
+#include "paddle/pten/infermeta/backward.h"
+
 namespace paddle {
 namespace operators {
 
@@ -343,25 +347,6 @@ class MatMulV2OpGrad : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;
 
  protected:
-  void InferShape(framework::InferShapeContext* context) const override {
-    OP_INOUT_CHECK(context->HasInput("X"), "Input", "X", "matmul_v2");
-    OP_INOUT_CHECK(context->HasInput("Y"), "Input", "Y", "matmul_v2");
-    OP_INOUT_CHECK(context->HasInput(framework::GradVarName("Out")), "Input",
-                   "Out@GRAD", "matmul_v2");
-    auto x_dims = context->GetInputDim("X");
-    auto y_dims = context->GetInputDim("Y");
-
-    auto x_grad_name = framework::GradVarName("X");
-    auto y_grad_name = framework::GradVarName("Y");
-
-    if (context->HasOutput(x_grad_name)) {
-      context->SetOutputDim(x_grad_name, x_dims);
-    }
-    if (context->HasOutput(y_grad_name)) {
-      context->SetOutputDim(y_grad_name, y_dims);
-    }
-  }
-
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
     auto input_data_type = OperatorWithKernel::IndicateVarDataType(
@@ -539,9 +524,12 @@ REGISTER_OPERATOR(matmul_v2, ops::MatMulV2Op, ops::MatMulV2OpMaker,
                   ops::MatMulV2GradOpMaker<paddle::framework::OpDesc>,
                   ops::MatMulV2GradOpMaker<paddle::imperative::OpBase>);
 
+DELCARE_INFER_SHAPE_FUNCTOR(matmul_v2_grad, MatMulV2GradInferShapeFunctor,
+                            PT_INFER_META(pten::MatmulGradInferMeta));
 REGISTER_OPERATOR(matmul_v2_grad, ops::MatMulV2OpGrad,
                   ops::MatMulV2OpDoubleGradMaker<paddle::framework::OpDesc>,
-                  ops::MatMulV2OpDoubleGradMaker<paddle::imperative::OpBase>);
+                  ops::MatMulV2OpDoubleGradMaker<paddle::imperative::OpBase>,
+                  MatMulV2GradInferShapeFunctor);
 
 REGISTER_OPERATOR(matmul_v2_grad_grad, ops::MatMulV2OpDoubleGrad,
                   ops::MatMulV2OpTripleGradMaker<paddle::framework::OpDesc>,
diff --git a/paddle/pten/core/compat/arg_map_context.h b/paddle/pten/core/compat/arg_map_context.h
index 42ab0f1fcc2bf3a19c67bce4e0475c8ee2bb3966..c2c2b0a518d6c30b441a55ecaeb88811f9187ea8 100644
--- a/paddle/pten/core/compat/arg_map_context.h
+++ b/paddle/pten/core/compat/arg_map_context.h
@@ -77,6 +77,7 @@ class ArgumentMappingContext {
 
   virtual bool HasInput(const std::string& name) const = 0;
   virtual bool HasOutput(const std::string& name) const = 0;
+  virtual bool HasAttr(const std::string& name) const = 0;
 
   // now we can't use Attribute here, it will cause pten relay on
   // boost::variant and BlockDesc
diff --git a/paddle/pten/core/infermeta_utils.h b/paddle/pten/core/infermeta_utils.h
index c95ae6b69f73b528c5f784d01fa2100f48971b46..6de91db9382e22537e577ce3188764034c7235e3 100644
--- a/paddle/pten/core/infermeta_utils.h
+++ b/paddle/pten/core/infermeta_utils.h
@@ -146,6 +146,7 @@ struct InferMetaFnImpl {
     }
   };
 
+  // TODO(chenweihang): support other attr types later
   PT_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(bool);
   PT_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(int);
   PT_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(int64_t);
diff --git a/paddle/pten/infermeta/backward.cc b/paddle/pten/infermeta/backward.cc
index b7bb17bdd1c38b1d2616bf31c52ecfdbd8626b55..db92449519436024a01c9c891f9671756777a345 100644
--- a/paddle/pten/infermeta/backward.cc
+++ b/paddle/pten/infermeta/backward.cc
@@ -23,8 +23,12 @@ void MatmulGradInferMeta(const MetaTensor& x,
                          bool transpose_y,
                          MetaTensor* dx,
                          MetaTensor* dy) {
-  dx->share_meta(x);
-  dy->share_meta(y);
+  if (dx) {
+    dx->share_meta(x);
+  }
+  if (dy) {
+    dy->share_meta(y);
+  }
 }
 
 }  // namespace pten
diff --git a/paddle/pten/ops/compat/matmul_sig.cc b/paddle/pten/ops/compat/matmul_sig.cc
index 963d5d6656b04aa94181bdc07ab7b0cf4d92de57..7f1f2cf437a4654881b0d28bf968e5fa5cab9783 100644
--- a/paddle/pten/ops/compat/matmul_sig.cc
+++ b/paddle/pten/ops/compat/matmul_sig.cc
@@ -17,10 +17,17 @@ limitations under the License. */
 namespace pten {
 
 KernelSignature MatmulGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
-  return KernelSignature("matmul_grad",
-                         {"X", "Y", GradVarName("Out")},
-                         {"trans_x", "trans_y"},
-                         {GradVarName("X"), GradVarName("Y")});
+  if (ctx.HasAttr("use_addto")) {
+    return KernelSignature("addto_matmul_grad",
+                           {"X", "Y", GradVarName("Out")},
+                           {"trans_x", "trans_y", "use_addto"},
+                           {GradVarName("X"), GradVarName("Y")});
+  } else {
+    return KernelSignature("matmul_grad",
+                           {"X", "Y", GradVarName("Out")},
+                           {"trans_x", "trans_y"},
+                           {GradVarName("X"), GradVarName("Y")});
+  }
 }
 
 KernelSignature MatmulDoubleGradOpArgumentMapping(
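
For context, a minimal self-contained sketch of the dispatch pattern the matmul_sig.cc hunk above relies on: an argument-mapping function queries HasAttr on its context and picks the richer kernel signature only when the optional attribute is actually present on the op. DemoContext and DemoSignature are hypothetical stand-ins for pten::ArgumentMappingContext and pten::KernelSignature; this is illustrative, not Paddle code.

// Sketch only: DemoContext/DemoSignature are hypothetical stand-ins,
// not Paddle types.
#include <iostream>
#include <set>
#include <string>
#include <utility>
#include <vector>

// Stand-in for pten::KernelSignature: only the kernel name and the
// attribute argument list matter for this demo.
struct DemoSignature {
  std::string kernel_name;
  std::vector<std::string> attr_args;
};

// Stand-in for pten::ArgumentMappingContext: only the HasAttr() query
// used by the mapping function is modeled.
class DemoContext {
 public:
  explicit DemoContext(std::set<std::string> attrs)
      : attrs_(std::move(attrs)) {}
  bool HasAttr(const std::string& name) const {
    return attrs_.count(name) > 0;
  }

 private:
  std::set<std::string> attrs_;
};

// Mirrors the shape of MatmulGradOpArgumentMapping above: select the
// addto variant only when the optional attribute exists on the op.
DemoSignature MapMatmulGrad(const DemoContext& ctx) {
  if (ctx.HasAttr("use_addto")) {
    return {"addto_matmul_grad", {"trans_x", "trans_y", "use_addto"}};
  }
  return {"matmul_grad", {"trans_x", "trans_y"}};
}

int main() {
  DemoContext with_addto({"trans_x", "trans_y", "use_addto"});
  DemoContext plain({"trans_x", "trans_y"});
  std::cout << MapMatmulGrad(with_addto).kernel_name << '\n';  // addto_matmul_grad
  std::cout << MapMatmulGrad(plain).kernel_name << '\n';       // matmul_grad
}

Dispatching on attribute presence keeps a single op definition compatible with both kernel variants, with no extra op registration required.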