diff --git a/paddle/fluid/framework/attribute.h b/paddle/fluid/framework/attribute.h index 6c4171a5b896aaf9c34ba62e1e2d16bd02fc5551..2164a21f3f892b2515dc77b94f7e4b91dba5dd1a 100644 --- a/paddle/fluid/framework/attribute.h +++ b/paddle/fluid/framework/attribute.h @@ -242,7 +242,7 @@ class AttrReader { return *attr_value; } - inline const Attribute& GetAttr(const std::string& name) const { + const Attribute* GetAttr(const std::string& name) const { auto it = attrs_.find(name); bool found = it != attrs_.end(); if (!found) { @@ -251,11 +251,10 @@ class AttrReader { found = it != default_attrs_->end(); } } - PADDLE_ENFORCE_EQ(found, true, - platform::errors::NotFound( - "Attribute (%s) should be in AttributeMap.", name)); - - return it->second; + if (found) { + return &it->second; + } + return nullptr; } private: diff --git a/paddle/fluid/framework/infershape_utils.cc b/paddle/fluid/framework/infershape_utils.cc index 01e594a176bd0ae061102679f0f03e32ca5ac467..8a64d4e19263576e36b1162d018c106d35c3eb91 100644 --- a/paddle/fluid/framework/infershape_utils.cc +++ b/paddle/fluid/framework/infershape_utils.cc @@ -52,8 +52,11 @@ class InferShapeArgumentMappingContext : public phi::ArgumentMappingContext { } paddle::any Attr(const std::string& name) const override { - auto& attr = ctx_.Attrs().GetAttr(name); - return GetAttrValue(attr); + auto* attr = ctx_.Attrs().GetAttr(name); + PADDLE_ENFORCE_NOT_NULL( + attr, platform::errors::NotFound( + "Attribute (%s) should be in AttributeMap.", name)); + return GetAttrValue(*attr); } size_t InputSize(const std::string& name) const override { @@ -450,216 +453,255 @@ CompatInferMetaContext BuildInferMetaContext(InferShapeContext* ctx, auto attr_reader = ctx->Attrs(); for (size_t i = 0; i < attr_names.size(); ++i) { auto& attr_name = attr_names[i]; - if (attr_defs[i].type_index == phi::AttributeType::INT_ARRAY) { - // When attr is a vector_tensor or tensor, transform it to IntArray - if (ctx->HasInputs(attr_name) || ctx->HasInput(attr_name)) { - auto infershape_inputs = std::move(ctx->GetInputVarPtrs(attr_name)); - if (ctx->IsRuntime()) { - // If is in runtime, we will get tensor's value for IntArray - // and push it into attrs - std::vector vars; - vars.reserve(infershape_inputs.size()); - for (size_t i = 0; i < infershape_inputs.size(); i++) { - vars.push_back(BOOST_GET_CONST(Variable*, infershape_inputs[i])); + VLOG(6) << "BuildInferMetaContext: " << attr_name << ": " + << attr_defs[i].type_index; + auto* attr_ptr = attr_reader.GetAttr(attr_name); + switch (attr_defs[i].type_index) { + case phi::AttributeType::SCALAR: + if (attr_ptr) { + auto& attr = *attr_ptr; + switch (AttrTypeID(attr)) { + case framework::proto::AttrType::FLOAT: + infer_meta_context.EmplaceBackAttr( + phi::Scalar(BOOST_GET_CONST(float, attr))); + break; + case framework::proto::AttrType::INT: + infer_meta_context.EmplaceBackAttr( + phi::Scalar(BOOST_GET_CONST(int, attr))); + break; + case framework::proto::AttrType::STRING: + infer_meta_context.EmplaceBackAttr( + phi::Scalar(BOOST_GET_CONST(std::string, attr))); + break; + default: + PADDLE_THROW(platform::errors::Unimplemented( + "Unsupported cast op attribute `%s` to Scalar when construct " + "InferMetaContext.", + attr_name)); } - if (infershape_inputs.size() != 1) { - infer_meta_context.EmplaceBackAttr( - std::move(experimental::MakePhiIntArrayFromVarList(vars))); + } else if (ctx->HasInput(attr_name)) { + auto infershape_input = std::move(ctx->GetInputVarPtrs(attr_name)); + if (infershape_input.size() == 1) { + if (ctx->IsRuntime()) { + Variable* var = BOOST_GET_CONST(Variable*, infershape_input[0]); + infer_meta_context.EmplaceBackAttr( + std::move(experimental::MakePhiScalarFromVar(*var))); + } else { + phi::Scalar tensor_scalar(-1); + tensor_scalar.SetFromTensor(true); + infer_meta_context.EmplaceBackAttr(std::move(tensor_scalar)); + } } else { - infer_meta_context.EmplaceBackAttr( - std::move(experimental::MakePhiIntArrayFromVar(*vars[0]))); + PADDLE_THROW(platform::errors::InvalidArgument( + "Invalid input.size() when cast op attribute `%s` to Scalar, " + "expected 1, but actually is %d .", + attr_name, infershape_input.size())); } } else { - // If is not in runtime, we will set default value(-1) for IntArray - std::vector vars; - vars.reserve(infershape_inputs.size()); - for (size_t i = 0; i < infershape_inputs.size(); ++i) { - vars.push_back(BOOST_GET_CONST(VarDesc*, infershape_inputs[i])); + // do nothing, skip current attr + } + break; + case phi::AttributeType::INT_ARRAY: + // When attr is a vector_tensor or tensor, transform it to IntArray + if (attr_ptr) { + auto& attr = *attr_ptr; + switch (AttrTypeID(attr)) { + case framework::proto::AttrType::INTS: + infer_meta_context.EmplaceBackAttr(std::move( + phi::IntArray(BOOST_GET_CONST(std::vector, attr)))); + break; + case framework::proto::AttrType::LONGS: + infer_meta_context.EmplaceBackAttr(std::move( + phi::IntArray(BOOST_GET_CONST(std::vector, attr)))); + break; + case framework::proto::AttrType::INT: + infer_meta_context.EmplaceBackAttr( + phi::IntArray({BOOST_GET_CONST(int, attr)})); + break; + default: + PADDLE_THROW(platform::errors::Unimplemented( + "Unsupported cast op attribute `%s` to IntArray when " + "construct InferMetaContext.", + attr_name)); } - - int64_t num_ele = 0; - if (vars.size() == 1) { - num_ele = 1; - const auto& tensor_dims = vars[0]->GetShape(); - for (size_t i = 0; i < tensor_dims.size(); ++i) { - num_ele *= tensor_dims[i]; + } else if (ctx->HasInputs(attr_name) || ctx->HasInput(attr_name)) { + auto infershape_inputs = std::move(ctx->GetInputVarPtrs(attr_name)); + if (ctx->IsRuntime()) { + // If is in runtime, we will get tensor's value for IntArray + // and push it into attrs + std::vector vars; + vars.reserve(infershape_inputs.size()); + for (size_t i = 0; i < infershape_inputs.size(); i++) { + vars.push_back(BOOST_GET_CONST(Variable*, infershape_inputs[i])); } - - if (num_ele <= 0) { - PADDLE_THROW(platform::errors::Unimplemented( - "Invalid number for construct phi::IntArray, expected " - "number > 0, but actually is %d. ", - num_ele)); + if (infershape_inputs.size() != 1) { + infer_meta_context.EmplaceBackAttr( + std::move(experimental::MakePhiIntArrayFromVarList(vars))); + } else { + infer_meta_context.EmplaceBackAttr( + std::move(experimental::MakePhiIntArrayFromVar(*vars[0]))); } - } else { - num_ele = vars.size(); + // If is not in runtime, we will set default value(-1) for IntArray + std::vector vars; + vars.reserve(infershape_inputs.size()); + for (size_t i = 0; i < infershape_inputs.size(); ++i) { + vars.push_back(BOOST_GET_CONST(VarDesc*, infershape_inputs[i])); + } + + int64_t num_ele = 0; + if (vars.size() == 1) { + num_ele = 1; + const auto& tensor_dims = vars[0]->GetShape(); + for (size_t i = 0; i < tensor_dims.size(); ++i) { + num_ele *= tensor_dims[i]; + } + + if (num_ele <= 0) { + PADDLE_THROW(platform::errors::Unimplemented( + "Invalid number for construct phi::IntArray, expected " + "number > 0, but actually is %d. ", + num_ele)); + } + + } else { + num_ele = vars.size(); + } + phi::IntArray tensor_attr(std::vector(num_ele, -1)); + tensor_attr.SetFromTensor(true); + infer_meta_context.EmplaceBackAttr(std::move(tensor_attr)); } - phi::IntArray tensor_attr(std::vector(num_ele, -1)); - tensor_attr.SetFromTensor(true); - infer_meta_context.EmplaceBackAttr(std::move(tensor_attr)); - } - } else if (ctx->HasAttr(attr_name)) { - auto& attr = attr_reader.GetAttr(attr_name); - if (AttrTypeID(attr) == proto::AttrType::INTS) { - infer_meta_context.EmplaceBackAttr(std::move( - phi::IntArray(BOOST_GET_CONST(std::vector, attr)))); - } else if (AttrTypeID(attr) == proto::AttrType::LONGS) { - infer_meta_context.EmplaceBackAttr(std::move( - phi::IntArray(BOOST_GET_CONST(std::vector, attr)))); - } else if (AttrTypeID(attr) == proto::AttrType::INT) { - infer_meta_context.EmplaceBackAttr( - phi::IntArray({BOOST_GET_CONST(int, attr)})); } else { - PADDLE_THROW(platform::errors::Unimplemented( - "Unsupported cast op attribute `%s` to IntArray when " - "construct InferMetaContext.", - attr_name)); + // do nothing, skip current attr } - } - } else if (attr_defs[i].type_index == phi::AttributeType::SCALAR) { - if (ctx->HasAttr(attr_name)) { - // TODO(chentianyu03): support other attrs later - auto& attr = attr_reader.GetAttr(attr_name); - if (AttrTypeID(attr) == proto::AttrType::FLOAT) { - infer_meta_context.EmplaceBackAttr( - phi::Scalar(BOOST_GET_CONST(float, attr))); - } else if (AttrTypeID(attr) == proto::AttrType::STRING) { - infer_meta_context.EmplaceBackAttr( - phi::Scalar(BOOST_GET_CONST(std::string, attr))); - } else if (AttrTypeID(attr) == proto::AttrType::INT) { - infer_meta_context.EmplaceBackAttr( - phi::Scalar(BOOST_GET_CONST(int, attr))); - } else { - PADDLE_THROW(platform::errors::Unimplemented( - "Unsupported cast op attribute `%s` to Scalar when construct " - "InferMetaContext.", - attr_name)); - } - } else if (ctx->HasInput(attr_name)) { - auto infershape_input = std::move(ctx->GetInputVarPtrs(attr_name)); - if (infershape_input.size() == 1) { - if (ctx->IsRuntime()) { - Variable* var = BOOST_GET_CONST(Variable*, infershape_input[0]); - infer_meta_context.EmplaceBackAttr( - std::move(experimental::MakePhiScalarFromVar(*var))); - } else { - phi::Scalar tensor_scalar(-1); - tensor_scalar.SetFromTensor(true); - infer_meta_context.EmplaceBackAttr(std::move(tensor_scalar)); + break; + case phi::AttributeType::SCALARS: + if (attr_ptr) { + auto& attr = *attr_ptr; + switch (AttrTypeID(attr)) { + case framework::proto::AttrType::INTS: { + const auto& vec = BOOST_GET_CONST(std::vector, attr); + std::vector scalar_list; + scalar_list.reserve(vec.size()); + for (const auto& val : vec) { + scalar_list.emplace_back(val); + } + infer_meta_context.EmplaceBackAttr(std::move(scalar_list)); + } break; + case framework::proto::AttrType::LONGS: { + const auto& vec = BOOST_GET_CONST(std::vector, attr); + std::vector scalar_list; + scalar_list.reserve(vec.size()); + for (const auto& val : vec) { + scalar_list.emplace_back(val); + } + infer_meta_context.EmplaceBackAttr(std::move(scalar_list)); + } break; + case framework::proto::AttrType::FLOATS: { + const auto& vec = BOOST_GET_CONST(std::vector, attr); + std::vector scalar_list; + scalar_list.reserve(vec.size()); + for (const auto& val : vec) { + scalar_list.emplace_back(val); + } + infer_meta_context.EmplaceBackAttr(std::move(scalar_list)); + } break; + case framework::proto::AttrType::FLOAT64S: { + const auto& vec = BOOST_GET_CONST(std::vector, attr); + std::vector scalar_list; + scalar_list.reserve(vec.size()); + for (const auto& val : vec) { + scalar_list.emplace_back(val); + } + infer_meta_context.EmplaceBackAttr(std::move(scalar_list)); + } break; + default: + PADDLE_THROW(platform::errors::Unimplemented( + "Unsupported cast op attribute `%s` to vector when " + "construct KernelContext.", + attr_names[i])); } } else { - PADDLE_THROW(platform::errors::InvalidArgument( - "Invalid input.size() when cast op attribute `%s` to Scalar, " - "expected 1, but actually is %d .", - attr_name, infershape_input.size())); - } - } - } else if (attr_defs[i].type_index == phi::AttributeType::SCALARS) { - auto& attr = attr_reader.GetAttr(attr_name); - if (AttrTypeID(attr) == proto::AttrType::INTS) { - const auto& vec = BOOST_GET_CONST(std::vector, attr); - std::vector scalar_list; - scalar_list.reserve(vec.size()); - for (const auto& val : vec) { - scalar_list.emplace_back(val); - } - infer_meta_context.EmplaceBackAttr(std::move(scalar_list)); - } else if (AttrTypeID(attr) == proto::AttrType::LONGS) { - const auto& vec = BOOST_GET_CONST(std::vector, attr); - std::vector scalar_list; - scalar_list.reserve(vec.size()); - for (const auto& val : vec) { - scalar_list.emplace_back(val); - } - infer_meta_context.EmplaceBackAttr(std::move(scalar_list)); - } else if (AttrTypeID(attr) == proto::AttrType::FLOATS) { - const auto& vec = BOOST_GET_CONST(std::vector, attr); - std::vector scalar_list; - scalar_list.reserve(vec.size()); - for (const auto& val : vec) { - scalar_list.emplace_back(val); - } - infer_meta_context.EmplaceBackAttr(std::move(scalar_list)); - } else if (AttrTypeID(attr) == proto::AttrType::FLOAT64S) { - const auto& vec = BOOST_GET_CONST(std::vector, attr); - std::vector scalar_list; - scalar_list.reserve(vec.size()); - for (const auto& val : vec) { - scalar_list.emplace_back(val); - } - infer_meta_context.EmplaceBackAttr(std::move(scalar_list)); - } else { - PADDLE_THROW(platform::errors::Unimplemented( - "Unsupported cast op attribute `%s` to vector when " - "construct InferMetaContext.", - attr_names[i])); - } - } else if (ctx->HasAttr(attr_name)) { - // Emplace Back Attr according to the type of attr. - auto& attr = attr_reader.GetAttr(attr_name); - if (attr_defs[i].type_index == phi::AttributeType::BOOL) { - infer_meta_context.EmplaceBackAttr(BOOST_GET_CONST(bool, attr)); - } else if (attr_defs[i].type_index == phi::AttributeType::INT32) { - infer_meta_context.EmplaceBackAttr(BOOST_GET_CONST(int, attr)); - } else if (attr_defs[i].type_index == phi::AttributeType::INT64) { - infer_meta_context.EmplaceBackAttr(BOOST_GET_CONST(int64_t, attr)); - } else if (attr_defs[i].type_index == phi::AttributeType::FLOAT32) { - infer_meta_context.EmplaceBackAttr(BOOST_GET_CONST(float, attr)); - } else if (attr_defs[i].type_index == phi::AttributeType::STRING) { - infer_meta_context.EmplaceBackAttr(BOOST_GET_CONST(std::string, attr)); - } else if (attr_defs[i].type_index == phi::AttributeType::BOOLS) { - infer_meta_context.EmplaceBackAttr( - BOOST_GET_CONST(std::vector, attr)); - } else if (attr_defs[i].type_index == phi::AttributeType::INT32S) { - infer_meta_context.EmplaceBackAttr( - BOOST_GET_CONST(std::vector, attr)); - } else if (attr_defs[i].type_index == phi::AttributeType::INT64S) { - if (AttrTypeID(attr) == proto::AttrType::INTS) { - // Emplace Back Attr according to the type of Phi_Kernel args. - const auto& vector_int_attr = BOOST_GET_CONST(std::vector, attr); - const std::vector vector_int64_attr(vector_int_attr.begin(), - vector_int_attr.end()); - infer_meta_context.EmplaceBackAttr(vector_int64_attr); - } else { - infer_meta_context.EmplaceBackAttr( - BOOST_GET_CONST(std::vector, attr)); + // do nothing, skip current attr } - } else if (attr_defs[i].type_index == phi::AttributeType::FLOAT32S) { - infer_meta_context.EmplaceBackAttr( - BOOST_GET_CONST(std::vector, attr)); - } else if (attr_defs[i].type_index == phi::AttributeType::FLOAT64S) { - infer_meta_context.EmplaceBackAttr( - BOOST_GET_CONST(std::vector, attr)); - } else if (attr_defs[i].type_index == phi::AttributeType::STRINGS) { - infer_meta_context.EmplaceBackAttr( - BOOST_GET_CONST(std::vector, attr)); - } else if (attr_defs[i].type_index == phi::AttributeType::DATA_TYPE) { - auto data_type = paddle::framework::TransToPhiDataType( - static_cast( - BOOST_GET_CONST(int, attr))); - infer_meta_context.EmplaceBackAttr(data_type); - } else { - PADDLE_THROW(platform::errors::Unimplemented( - "Unsupported attribute type is received when call " - "InferShapeFunctor.")); - } - } else if (ctx->HasInput(attr_name)) { - // convert from data - if (attr_defs[i].type_index == phi::AttributeType::INT32) { - if (ctx->IsRuntime()) { - auto infershape_inputs = std::move(ctx->GetInputVarPtrs(attr_name)); - auto var_temp = BOOST_GET_CONST(Variable*, infershape_inputs[i]); - auto val = experimental::MakePhiScalarFromVar(*var_temp); - int32_t val_int = val.template to(); - infer_meta_context.EmplaceBackAttr(val_int); + break; + default: + if (attr_ptr) { + auto& attr = *attr_ptr; + switch (attr_defs[i].type_index) { + case phi::AttributeType::FLOAT32: + infer_meta_context.EmplaceBackAttr(BOOST_GET_CONST(float, attr)); + break; + case phi::AttributeType::INT32: + infer_meta_context.EmplaceBackAttr(BOOST_GET_CONST(int, attr)); + break; + case phi::AttributeType::BOOL: + infer_meta_context.EmplaceBackAttr(BOOST_GET_CONST(bool, attr)); + break; + case phi::AttributeType::INT64: + infer_meta_context.EmplaceBackAttr( + BOOST_GET_CONST(int64_t, attr)); + break; + case phi::AttributeType::INT32S: + infer_meta_context.EmplaceBackAttr( + BOOST_GET_CONST(std::vector, attr)); + break; + case phi::AttributeType::DATA_TYPE: { + auto data_type = paddle::framework::TransToPhiDataType( + static_cast( + BOOST_GET_CONST(int, attr))); + infer_meta_context.EmplaceBackAttr(data_type); + } break; + case phi::AttributeType::STRING: + infer_meta_context.EmplaceBackAttr( + BOOST_GET_CONST(std::string, attr)); + break; + case phi::AttributeType::INT64S: + switch (AttrTypeID(attr)) { + case framework::proto::AttrType::LONGS: + infer_meta_context.EmplaceBackAttr( + BOOST_GET_CONST(std::vector, attr)); + break; + case framework::proto::AttrType::INTS: { + const auto& vector_int_attr = + BOOST_GET_CONST(std::vector, attr); + const std::vector vector_int64_attr( + vector_int_attr.begin(), vector_int_attr.end()); + infer_meta_context.EmplaceBackAttr(vector_int64_attr); + } break; + default: + PADDLE_THROW(platform::errors::Unimplemented( + "Unsupported cast op attribute `%s` to vector " + "when " + "construct KernelContext.", + attr_names[i])); + } + break; + case phi::AttributeType::FLOAT32S: + infer_meta_context.EmplaceBackAttr( + BOOST_GET_CONST(std::vector, attr)); + break; + case phi::AttributeType::STRINGS: + infer_meta_context.EmplaceBackAttr( + BOOST_GET_CONST(std::vector, attr)); + break; + case phi::AttributeType::BOOLS: + infer_meta_context.EmplaceBackAttr( + BOOST_GET_CONST(std::vector, attr)); + break; + case phi::AttributeType::FLOAT64S: + infer_meta_context.EmplaceBackAttr( + BOOST_GET_CONST(std::vector, attr)); + break; + default: + PADDLE_THROW(platform::errors::Unimplemented( + "Unsupported cast op attribute `%s` when construct " + "KernelContext in dygraph.", + attr_names[i])); + } } else { - infer_meta_context.EmplaceBackAttr(-1); + // do nothing, skip currnet attr } - } else { - PADDLE_THROW(platform::errors::Unimplemented( - "Get value from variable only support int yet")); - } } } diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index 0c22321996b8f566c2dbb0efef467a41be7046ce..18287f0c7a4eec2c93b049251ddc0b0863604828 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -2469,163 +2469,210 @@ void OperatorWithKernel::BuildPhiKernelContext( VLOG(4) << "Done outputs"; for (size_t i = 0; i < attr_names.size(); ++i) { - if (attr_defs[i].type_index == phi::AttributeType::INT_ARRAY) { - auto attr_iter = Attrs().find(attr_names[i]); - if (attr_iter != Attrs().end()) { // shape is in the attribute - auto& attr = attr_iter->second; - if (AttrTypeID(attr) == proto::AttrType::LONGS) { - pt_kernel_context->EmplaceBackAttr(std::move( - phi::IntArray(BOOST_GET_CONST(std::vector, attr)))); - } else if (AttrTypeID(attr) == proto::AttrType::INTS) { - pt_kernel_context->EmplaceBackAttr(std::move( - phi::IntArray(BOOST_GET_CONST(std::vector, attr)))); - } else if (AttrTypeID(attr) == proto::AttrType::INT) { - pt_kernel_context->EmplaceBackAttr( - std::move(phi::IntArray(&BOOST_GET_CONST(int32_t, attr), 1))); - } else { - PADDLE_THROW(platform::errors::Unimplemented( - "Unsupported cast op attribute `%s` to IntArray when " - "construct KernelContext.", - attr_names[i])); - } - } else { // shape is in the input - auto& ins_vector = ctx.inputs.at(attr_names[i]); - if (ins_vector.size() == 1) { // ShapeTensor + VLOG(6) << "BuildPhiKernelContext: " << attr_names[i] << ": " + << attr_defs[i].type_index; + auto attr_iter = Attrs().find(attr_names[i]); + switch (attr_defs[i].type_index) { + case phi::AttributeType::SCALAR: + if (attr_iter != Attrs().end()) { + // scalar is in the attribute + switch (AttrTypeID(attr_iter->second)) { + case proto::AttrType::FLOAT: + pt_kernel_context->EmplaceBackAttr(std::move( + phi::Scalar(BOOST_GET_CONST(float, attr_iter->second)))); + break; + case proto::AttrType::INT: + pt_kernel_context->EmplaceBackAttr(std::move( + phi::Scalar(BOOST_GET_CONST(int, attr_iter->second)))); + break; + case proto::AttrType::STRING: + pt_kernel_context->EmplaceBackAttr(std::move(phi::Scalar( + BOOST_GET_CONST(std::string, attr_iter->second)))); + break; + default: + PADDLE_THROW(platform::errors::Unimplemented( + "Unsupported cast op attribute `%s` to Scalar when construct " + "KernelContext in dygraph.", + attr_names[i])); + } + } else { // scalar is in the input + auto& ins_vector = ctx.inputs.at(attr_names[i]); pt_kernel_context->EmplaceBackAttr(std::move( - experimental::MakePhiIntArrayFromVar(*ins_vector.front()))); - } else { // ShapeTensorList - pt_kernel_context->EmplaceBackAttr( - std::move(experimental::MakePhiIntArrayFromVarList(ins_vector))); - } - } - } else if (attr_defs[i].type_index == phi::AttributeType::SCALAR) { - auto attr_iter = Attrs().find(attr_names[i]); - if (attr_iter != Attrs().end()) { // scalar is in the attribute - auto& attr = attr_iter->second; - if (AttrTypeID(attr) == proto::AttrType::FLOAT) { - pt_kernel_context->EmplaceBackAttr( - std::move(phi::Scalar(BOOST_GET_CONST(float, attr)))); - } else if (AttrTypeID(attr) == proto::AttrType::STRING) { - pt_kernel_context->EmplaceBackAttr( - std::move(phi::Scalar(BOOST_GET_CONST(std::string, attr)))); - } else if (AttrTypeID(attr) == proto::AttrType::INT) { - pt_kernel_context->EmplaceBackAttr( - std::move(phi::Scalar(BOOST_GET_CONST(int, attr)))); - } else { - PADDLE_THROW(platform::errors::Unimplemented( - "Unsupported cast op attribute `%s` to Scalar when construct " - "KernelContext.", - attr_names[i])); - } - } else { - auto& ins_vector = ctx.inputs.at(attr_names[i]); - pt_kernel_context->EmplaceBackAttr( - std::move(experimental::MakePhiScalarFromVar(*ins_vector.front()))); - } - - } else if (attr_defs[i].type_index == phi::AttributeType::SCALARS) { - auto& attr = Attrs().at(attr_names[i]); - if (AttrTypeID(attr) == proto::AttrType::INTS) { - const auto& vec = BOOST_GET_CONST(std::vector, attr); - std::vector scalar_list; - scalar_list.reserve(vec.size()); - for (const auto& val : vec) { - scalar_list.emplace_back(val); + experimental::MakePhiScalarFromVar(*ins_vector.front()))); } - pt_kernel_context->EmplaceBackAttr(std::move(scalar_list)); - } else if (AttrTypeID(attr) == proto::AttrType::LONGS) { - const auto& vec = BOOST_GET_CONST(std::vector, attr); - std::vector scalar_list; - scalar_list.reserve(vec.size()); - for (const auto& val : vec) { - scalar_list.emplace_back(val); - } - pt_kernel_context->EmplaceBackAttr(std::move(scalar_list)); - } else if (AttrTypeID(attr) == proto::AttrType::FLOATS) { - const auto& vec = BOOST_GET_CONST(std::vector, attr); - std::vector scalar_list; - scalar_list.reserve(vec.size()); - for (const auto& val : vec) { - scalar_list.emplace_back(val); - } - pt_kernel_context->EmplaceBackAttr(std::move(scalar_list)); - } else if (AttrTypeID(attr) == proto::AttrType::FLOAT64S) { - const auto& vec = BOOST_GET_CONST(std::vector, attr); - std::vector scalar_list; - scalar_list.reserve(vec.size()); - for (const auto& val : vec) { - scalar_list.emplace_back(val); + break; + case phi::AttributeType::INT_ARRAY: + if (attr_iter != Attrs().end()) { + switch (AttrTypeID(attr_iter->second)) { + case proto::AttrType::INTS: + pt_kernel_context->EmplaceBackAttr(std::move(phi::IntArray( + BOOST_GET_CONST(std::vector, attr_iter->second)))); + break; + case proto::AttrType::LONGS: + pt_kernel_context->EmplaceBackAttr(std::move(phi::IntArray( + BOOST_GET_CONST(std::vector, attr_iter->second)))); + break; + case proto::AttrType::INT: + pt_kernel_context->EmplaceBackAttr(std::move(phi::IntArray( + &BOOST_GET_CONST(int32_t, attr_iter->second), 1))); + break; + case proto::AttrType::LONG: + pt_kernel_context->EmplaceBackAttr(std::move(phi::IntArray( + &BOOST_GET_CONST(int64_t, attr_iter->second), 1))); + break; + default: + PADDLE_THROW(platform::errors::Unimplemented( + "Unsupported cast op attribute `%s` to IntArray when " + "construct KernelContext.", + attr_names[i])); + } + } else { // shape is in the input + auto& ins_vector = ctx.inputs.at(attr_names[i]); + if (ins_vector.size() == 1) { // ShapeTensor + pt_kernel_context->EmplaceBackAttr(std::move( + experimental::MakePhiIntArrayFromVar(*ins_vector.front()))); + } else { // ShapeTensorList + pt_kernel_context->EmplaceBackAttr(std::move( + experimental::MakePhiIntArrayFromVarList(ins_vector))); + } } - pt_kernel_context->EmplaceBackAttr(std::move(scalar_list)); - } else { - PADDLE_THROW(platform::errors::Unimplemented( - "Unsupported cast op attribute `%s` to vector when " - "construct KernelContext.", - attr_names[i])); - } - } else { - auto attr_it = attrs_.find(attr_names[i]); - if (attr_defs[i].type_index == phi::AttributeType::INT32) { - if (attr_it == attrs_.end()) { - auto in_it = ctx.inputs.find(attr_names[i]); - if (in_it != ctx.inputs.end()) { - // get data from input - auto val = experimental::MakePhiScalarFromVar(*(in_it->second[0])); - int32_t val_int = val.template to(); - pt_kernel_context->EmplaceBackAttr(val_int); - } else { - PADDLE_THROW(platform::errors::NotFound( - "can not find attribute `%s` both in attribute and input ", + break; + case phi::AttributeType::SCALARS: { + PADDLE_ENFORCE_NE( + attr_iter, Attrs().end(), + platform::errors::NotFound("(%s) is not found in AttributeMap when " + "buildind static KernelContext.", + attr_names[i])); + switch (AttrTypeID(attr_iter->second)) { + case proto::AttrType::INTS: { + const auto& vec = + BOOST_GET_CONST(std::vector, attr_iter->second); + std::vector scalar_list; + scalar_list.reserve(vec.size()); + for (const auto& val : vec) { + scalar_list.emplace_back(val); + } + pt_kernel_context->EmplaceBackAttr(std::move(scalar_list)); + } break; + case proto::AttrType::LONGS: { + const auto& vec = + BOOST_GET_CONST(std::vector, attr_iter->second); + std::vector scalar_list; + scalar_list.reserve(vec.size()); + for (const auto& val : vec) { + scalar_list.emplace_back(val); + } + pt_kernel_context->EmplaceBackAttr(std::move(scalar_list)); + } break; + case proto::AttrType::FLOATS: { + const auto& vec = + BOOST_GET_CONST(std::vector, attr_iter->second); + std::vector scalar_list; + scalar_list.reserve(vec.size()); + for (const auto& val : vec) { + scalar_list.emplace_back(val); + } + pt_kernel_context->EmplaceBackAttr(std::move(scalar_list)); + } break; + case proto::AttrType::FLOAT64S: { + const auto& vec = + BOOST_GET_CONST(std::vector, attr_iter->second); + std::vector scalar_list; + scalar_list.reserve(vec.size()); + for (const auto& val : vec) { + scalar_list.emplace_back(val); + } + pt_kernel_context->EmplaceBackAttr(std::move(scalar_list)); + } break; + case proto::AttrType::BOOLEANS: { + const auto& vec = + BOOST_GET_CONST(std::vector, attr_iter->second); + std::vector scalar_list; + scalar_list.reserve(vec.size()); + for (const auto& val : vec) { + scalar_list.emplace_back(val); + } + pt_kernel_context->EmplaceBackAttr(std::move(scalar_list)); + } break; + default: + PADDLE_THROW(platform::errors::Unimplemented( + "Unsupported cast op attribute `%s` to vector when " + "construct KernelContext.", attr_names[i])); - } - } else { - pt_kernel_context->EmplaceBackAttr( - BOOST_GET_CONST(int, attr_it->second)); } - } else if (attr_defs[i].type_index == phi::AttributeType::FLOAT32) { - pt_kernel_context->EmplaceBackAttr( - BOOST_GET_CONST(float, attr_it->second)); - } else if (attr_defs[i].type_index == phi::AttributeType::BOOL) { - pt_kernel_context->EmplaceBackAttr( - BOOST_GET_CONST(bool, attr_it->second)); - } else if (attr_defs[i].type_index == phi::AttributeType::INT64) { - pt_kernel_context->EmplaceBackAttr( - BOOST_GET_CONST(int64_t, attr_it->second)); - } else if (attr_defs[i].type_index == phi::AttributeType::STRING) { - pt_kernel_context->EmplaceBackAttr( - BOOST_GET_CONST(std::string, attr_it->second)); - } else if (attr_defs[i].type_index == phi::AttributeType::DATA_TYPE) { - auto data_type = paddle::framework::TransToPhiDataType( - static_cast( - BOOST_GET_CONST(int, attr_it->second))); - pt_kernel_context->EmplaceBackAttr(data_type); - } else if (attr_defs[i].type_index == phi::AttributeType::INT64S) { - if (AttrTypeID(attr_it->second) == proto::AttrType::LONGS) { - pt_kernel_context->EmplaceBackAttr( - BOOST_GET_CONST(std::vector, attr_it->second)); - } else if (AttrTypeID(attr_it->second) == proto::AttrType::INTS) { - // Emplace Back Attr according to the type of Phi_Kernel args. - const auto& vector_int_attr = - BOOST_GET_CONST(std::vector, attr_it->second); - const std::vector vector_int64_attr(vector_int_attr.begin(), - vector_int_attr.end()); - pt_kernel_context->EmplaceBackAttr(vector_int64_attr); + } break; + default: { + PADDLE_ENFORCE_NE( + attr_iter, Attrs().end(), + platform::errors::NotFound("(%s) is not found in AttributeMap when " + "buildind static KernelContext.", + attr_names[i])); + switch (attr_defs[i].type_index) { + case phi::AttributeType::FLOAT32: + pt_kernel_context->EmplaceBackAttr( + BOOST_GET_CONST(float, attr_iter->second)); + break; + case phi::AttributeType::INT32: + pt_kernel_context->EmplaceBackAttr( + BOOST_GET_CONST(int, attr_iter->second)); + break; + case phi::AttributeType::BOOL: + pt_kernel_context->EmplaceBackAttr( + BOOST_GET_CONST(bool, attr_iter->second)); + break; + case phi::AttributeType::INT64: + pt_kernel_context->EmplaceBackAttr( + BOOST_GET_CONST(int64_t, attr_iter->second)); + break; + case phi::AttributeType::INT32S: + pt_kernel_context->EmplaceBackAttr( + BOOST_GET_CONST(std::vector, attr_iter->second)); + break; + case phi::AttributeType::DATA_TYPE: { + auto data_type = framework::TransToPhiDataType( + static_cast( + BOOST_GET_CONST(int, attr_iter->second))); + pt_kernel_context->EmplaceBackAttr(data_type); + } break; + case phi::AttributeType::STRING: + pt_kernel_context->EmplaceBackAttr( + std::move(BOOST_GET_CONST(std::string, attr_iter->second))); + break; + case phi::AttributeType::INT64S: + switch (AttrTypeID(attr_iter->second)) { + case proto::AttrType::LONGS: + pt_kernel_context->EmplaceBackAttr( + BOOST_GET_CONST(std::vector, attr_iter->second)); + break; + case proto::AttrType::INTS: { + const auto& vector_int_attr = + BOOST_GET_CONST(std::vector, attr_iter->second); + const std::vector vector_int64_attr( + vector_int_attr.begin(), vector_int_attr.end()); + pt_kernel_context->EmplaceBackAttr(vector_int64_attr); + } break; + default: + PADDLE_THROW(platform::errors::Unimplemented( + "Unsupported cast op attribute `%s` to vector " + "when " + "construct KernelContext.", + attr_names[i])); + } + break; + case phi::AttributeType::FLOAT32S: + pt_kernel_context->EmplaceBackAttr( + BOOST_GET_CONST(std::vector, attr_iter->second)); + break; + case phi::AttributeType::STRINGS: + pt_kernel_context->EmplaceBackAttr( + BOOST_GET_CONST(std::vector, attr_iter->second)); + break; + default: + PADDLE_THROW(platform::errors::Unimplemented( + "Unsupported cast op attribute `%s` when construct " + "KernelContext in dygraph.", + attr_names[i])); } - } else if (attr_defs[i].type_index == phi::AttributeType::INT32S) { - const auto& vector_int_attr = - BOOST_GET_CONST(std::vector, attr_it->second); - pt_kernel_context->EmplaceBackAttr(vector_int_attr); - } else if (attr_defs[i].type_index == phi::AttributeType::STRINGS) { - pt_kernel_context->EmplaceBackAttr( - BOOST_GET_CONST(std::vector, attr_it->second)); - } else if (attr_defs[i].type_index == phi::AttributeType::FLOAT32S) { - pt_kernel_context->EmplaceBackAttr( - BOOST_GET_CONST(std::vector, attr_it->second)); - } else { - PADDLE_THROW(platform::errors::Unimplemented( - "Unsupported cast op attribute `%s` when construct " - "KernelContext.", - attr_names[i])); } } } diff --git a/paddle/fluid/imperative/prepared_operator.h b/paddle/fluid/imperative/prepared_operator.h index 9e729fee69d8642e928b7e600c1507a73bb85134..129f75e75de1e8a2072b6cf8409a12587b5e39ec 100644 --- a/paddle/fluid/imperative/prepared_operator.h +++ b/paddle/fluid/imperative/prepared_operator.h @@ -220,7 +220,7 @@ class PreparedOp { static const phi::DefaultKernelSignatureMap& default_phi_kernel_sig_map; }; -const inline framework::Attribute& GetAttr( +const inline framework::Attribute* GetAttr( const framework::AttributeMap& attrs, const framework::AttributeMap& default_attrs, const std::string& name) { auto it = attrs.find(name); @@ -229,10 +229,10 @@ const inline framework::Attribute& GetAttr( it = default_attrs.find(name); found = it != default_attrs.end(); } - PADDLE_ENFORCE_EQ( - found, true, - platform::errors::NotFound("(%s) is not found in AttributeMap.", name)); - return it->second; + if (found) { + return &it->second; + } + return nullptr; } template @@ -330,6 +330,7 @@ void BuildDygraphPhiKernelContext(const phi::KernelSignature& kernel_signature, } kernel_ctx->AssignInputRange(std::make_pair(start_idx, end_idx), i); } + VLOG(6) << "BuildDygraphPhiKernelContext: Inputs parsing completed."; for (size_t i = 0; i < output_names.size(); ++i) { size_t start_idx = (i == 0 ? 0 : kernel_ctx->OutputRangeAt(i - 1).second); @@ -380,178 +381,217 @@ void BuildDygraphPhiKernelContext(const phi::KernelSignature& kernel_signature, } kernel_ctx->AssignOutputRange(std::make_pair(start_idx, end_idx), i); } + VLOG(6) << "BuildDygraphPhiKernelContext: Outputs parsing completed."; for (size_t i = 0; i < attr_names.size(); ++i) { - if (attr_defs[i].type_index == phi::AttributeType::INT_ARRAY) { - if (attrs.find(attr_names[i]) != - attrs.end()) { // shape is in the attribute - auto& attr = GetAttr(attrs, default_attrs, attr_names[i]); - if (AttrTypeID(attr) == framework::proto::AttrType::LONGS) { - kernel_ctx->EmplaceBackAttr(std::move( - phi::IntArray(BOOST_GET_CONST(std::vector, attr)))); - } else if (AttrTypeID(attr) == framework::proto::AttrType::INTS) { + VLOG(6) << "BuildDygraphPhiKernelContext: " << attr_names[i] << ": " + << attr_defs[i].type_index; + auto* attr_ptr = GetAttr(attrs, default_attrs, attr_names[i]); + switch (attr_defs[i].type_index) { + case phi::AttributeType::SCALAR: + if (attr_ptr) { + // scalar is in the attribute + auto& attr = *attr_ptr; + switch (AttrTypeID(attr)) { + case framework::proto::AttrType::FLOAT: + kernel_ctx->EmplaceBackAttr( + std::move(phi::Scalar(BOOST_GET_CONST(float, attr)))); + break; + case framework::proto::AttrType::INT: + kernel_ctx->EmplaceBackAttr( + std::move(phi::Scalar(BOOST_GET_CONST(int, attr)))); + break; + case framework::proto::AttrType::STRING: + kernel_ctx->EmplaceBackAttr( + std::move(phi::Scalar(BOOST_GET_CONST(std::string, attr)))); + break; + default: + PADDLE_THROW(platform::errors::Unimplemented( + "Unsupported cast op attribute `%s` to Scalar when construct " + "KernelContext in dygraph.", + attr_names[i])); + } + } else { // scalar is in the input + auto& ins_vector = ins.at(attr_names[i]); kernel_ctx->EmplaceBackAttr(std::move( - phi::IntArray(BOOST_GET_CONST(std::vector, attr)))); - } else if (AttrTypeID(attr) == framework::proto::AttrType::LONG) { - kernel_ctx->EmplaceBackAttr( - std::move(phi::IntArray(&BOOST_GET_CONST(int64_t, attr), 1))); - } else if (AttrTypeID(attr) == framework::proto::AttrType::INT) { - kernel_ctx->EmplaceBackAttr( - std::move(phi::IntArray(&BOOST_GET_CONST(int32_t, attr), 1))); - } else if (attr_defs[i].type_index == phi::AttributeType::INT32S) { - const auto& vector_int_attr = BOOST_GET_CONST(std::vector, attr); - kernel_ctx->EmplaceBackAttr(vector_int_attr); - } else { - PADDLE_THROW(platform::errors::Unimplemented( - "Unsupported cast op attribute `%s` to VectorTensor when " - "construct KernelContext.", - attr_names[i])); + experimental::MakePhiScalarFromVar(ins_vector[0]->Var()))); } - } else { // shape is in the input - auto& ins_vector = ins.at(attr_names[i]); - if (ins_vector.size() == 1) { // ShapeTensor - kernel_ctx->EmplaceBackAttr(std::move( - experimental::MakePhiIntArrayFromVar(ins_vector[0]->Var()))); - } else { // ShapeTensorList - std::vector variables; - variables.reserve(ins_vector.size()); - for (const auto& var_base : ins_vector) { - variables.push_back(var_base->MutableVar()); + break; + case phi::AttributeType::INT_ARRAY: + if (attr_ptr) { + auto& attr = *attr_ptr; + switch (AttrTypeID(attr)) { + case framework::proto::AttrType::INTS: + kernel_ctx->EmplaceBackAttr(std::move( + phi::IntArray(BOOST_GET_CONST(std::vector, attr)))); + break; + case framework::proto::AttrType::LONGS: + kernel_ctx->EmplaceBackAttr(std::move( + phi::IntArray(BOOST_GET_CONST(std::vector, attr)))); + break; + case framework::proto::AttrType::INT: + kernel_ctx->EmplaceBackAttr( + std::move(phi::IntArray(&BOOST_GET_CONST(int32_t, attr), 1))); + break; + case framework::proto::AttrType::LONG: + kernel_ctx->EmplaceBackAttr( + std::move(phi::IntArray(&BOOST_GET_CONST(int64_t, attr), 1))); + break; + default: + PADDLE_THROW(platform::errors::Unimplemented( + "Unsupported cast op attribute `%s` to IntArray when " + "construct KernelContext.", + attr_names[i])); + } + } else { // shape is in the input + auto& ins_vector = ins.at(attr_names[i]); + if (ins_vector.size() == 1) { // ShapeTensor + kernel_ctx->EmplaceBackAttr(std::move( + experimental::MakePhiIntArrayFromVar(ins_vector[0]->Var()))); + } else { // ShapeTensorList + std::vector variables; + variables.reserve(ins_vector.size()); + for (const auto& var_base : ins_vector) { + variables.push_back(var_base->MutableVar()); + } + kernel_ctx->EmplaceBackAttr( + std::move(experimental::MakePhiIntArrayFromVarList(variables))); } - kernel_ctx->EmplaceBackAttr( - std::move(experimental::MakePhiIntArrayFromVarList(variables))); - } - } - } else if (attr_defs[i].type_index == phi::AttributeType::SCALAR) { - // TODO(zhangyunfei): Scalar should hold scaler type, and we should check - // attribtue type by attr_defs - if (attrs.find(attr_names[i]) != attrs.end() || - default_attrs.find(attr_names[i]) != - default_attrs.end()) { // scalar is in the attribute - auto& attr = GetAttr(attrs, default_attrs, attr_names[i]); - if (AttrTypeID(attr) == framework::proto::AttrType::FLOAT) { - kernel_ctx->EmplaceBackAttr( - std::move(phi::Scalar(BOOST_GET_CONST(float, attr)))); - } else if (AttrTypeID(attr) == framework::proto::AttrType::STRING) { - kernel_ctx->EmplaceBackAttr( - std::move(phi::Scalar(BOOST_GET_CONST(std::string, attr)))); - } else if (AttrTypeID(attr) == framework::proto::AttrType::INT) { - kernel_ctx->EmplaceBackAttr( - std::move(phi::Scalar(BOOST_GET_CONST(int, attr)))); - } else { - PADDLE_THROW(platform::errors::Unimplemented( - "Unsupported cast op attribute `%s` to Scalar when construct " - "KernelContext in dygraph.", - attr_names[i])); - } - } else { // scalar is in the input - auto& ins_vector = ins.at(attr_names[i]); - kernel_ctx->EmplaceBackAttr(std::move( - experimental::MakePhiScalarFromVar(ins_vector[0]->Var()))); - } - - } else if (ins.find(attr_names[i]) != ins.end()) { - // deal tensor attr here - auto& ins_vector = ins.at(attr_names[i]); - auto tensor_attr = - experimental::MakePhiScalarFromVar(ins_vector[0]->Var()); - if (attr_defs[i].type_index == phi::AttributeType::INT32) { - int val = tensor_attr.template to(); - kernel_ctx->EmplaceBackAttr(val); - } else { - PADDLE_THROW(platform::errors::Unimplemented("only support int here")); - } - } else if (attr_defs[i].type_index == phi::AttributeType::SCALARS) { - auto& attr = GetAttr(attrs, default_attrs, attr_names[i]); - if (AttrTypeID(attr) == framework::proto::AttrType::INTS) { - const auto& vec = BOOST_GET_CONST(std::vector, attr); - std::vector scalar_list; - scalar_list.reserve(vec.size()); - for (const auto& val : vec) { - scalar_list.emplace_back(val); - } - kernel_ctx->EmplaceBackAttr(std::move(scalar_list)); - } else if (AttrTypeID(attr) == framework::proto::AttrType::LONGS) { - const auto& vec = BOOST_GET_CONST(std::vector, attr); - std::vector scalar_list; - scalar_list.reserve(vec.size()); - for (const auto& val : vec) { - scalar_list.emplace_back(val); - } - kernel_ctx->EmplaceBackAttr(std::move(scalar_list)); - } else if (AttrTypeID(attr) == framework::proto::AttrType::FLOATS) { - const auto& vec = BOOST_GET_CONST(std::vector, attr); - std::vector scalar_list; - scalar_list.reserve(vec.size()); - for (const auto& val : vec) { - scalar_list.emplace_back(val); - } - kernel_ctx->EmplaceBackAttr(std::move(scalar_list)); - } else if (AttrTypeID(attr) == framework::proto::AttrType::FLOAT64S) { - const auto& vec = BOOST_GET_CONST(std::vector, attr); - std::vector scalar_list; - scalar_list.reserve(vec.size()); - for (const auto& val : vec) { - scalar_list.emplace_back(val); } - kernel_ctx->EmplaceBackAttr(std::move(scalar_list)); - } else if (AttrTypeID(attr) == framework::proto::AttrType::BOOLEANS) { - const auto& vec = BOOST_GET_CONST(std::vector, attr); - std::vector scalar_list; - scalar_list.reserve(vec.size()); - for (const auto& val : vec) { - scalar_list.emplace_back(val); + break; + case phi::AttributeType::SCALARS: { + PADDLE_ENFORCE_NOT_NULL( + attr_ptr, + platform::errors::NotFound("(%s) is not found in AttributeMap when " + "buildind dygraph KernelContext.", + attr_names[i])); + auto& attr = *attr_ptr; + switch (AttrTypeID(attr)) { + case framework::proto::AttrType::INTS: { + const auto& vec = BOOST_GET_CONST(std::vector, attr); + std::vector scalar_list; + scalar_list.reserve(vec.size()); + for (const auto& val : vec) { + scalar_list.emplace_back(val); + } + kernel_ctx->EmplaceBackAttr(std::move(scalar_list)); + } break; + case framework::proto::AttrType::LONGS: { + const auto& vec = BOOST_GET_CONST(std::vector, attr); + std::vector scalar_list; + scalar_list.reserve(vec.size()); + for (const auto& val : vec) { + scalar_list.emplace_back(val); + } + kernel_ctx->EmplaceBackAttr(std::move(scalar_list)); + } break; + case framework::proto::AttrType::FLOATS: { + const auto& vec = BOOST_GET_CONST(std::vector, attr); + std::vector scalar_list; + scalar_list.reserve(vec.size()); + for (const auto& val : vec) { + scalar_list.emplace_back(val); + } + kernel_ctx->EmplaceBackAttr(std::move(scalar_list)); + } break; + case framework::proto::AttrType::FLOAT64S: { + const auto& vec = BOOST_GET_CONST(std::vector, attr); + std::vector scalar_list; + scalar_list.reserve(vec.size()); + for (const auto& val : vec) { + scalar_list.emplace_back(val); + } + kernel_ctx->EmplaceBackAttr(std::move(scalar_list)); + } break; + case framework::proto::AttrType::BOOLEANS: { + const auto& vec = BOOST_GET_CONST(std::vector, attr); + std::vector scalar_list; + scalar_list.reserve(vec.size()); + for (const auto& val : vec) { + scalar_list.emplace_back(val); + } + kernel_ctx->EmplaceBackAttr(std::move(scalar_list)); + } break; + default: + PADDLE_THROW(platform::errors::Unimplemented( + "Unsupported cast op attribute `%s` to vector when " + "construct KernelContext.", + attr_names[i])); } - kernel_ctx->EmplaceBackAttr(std::move(scalar_list)); - } else { - PADDLE_THROW(platform::errors::Unimplemented( - "Unsupported cast op attribute `%s` to vector when " - "construct KernelContext.", - attr_names[i])); - } - } else { - auto& attr = GetAttr(attrs, default_attrs, attr_names[i]); - if (attr_defs[i].type_index == phi::AttributeType::INT32) { - kernel_ctx->EmplaceBackAttr(BOOST_GET_CONST(int, attr)); - } else if (attr_defs[i].type_index == phi::AttributeType::FLOAT32) { - kernel_ctx->EmplaceBackAttr(BOOST_GET_CONST(float, attr)); - } else if (attr_defs[i].type_index == phi::AttributeType::BOOL) { - kernel_ctx->EmplaceBackAttr(BOOST_GET_CONST(bool, attr)); - } else if (attr_defs[i].type_index == phi::AttributeType::INT64) { - kernel_ctx->EmplaceBackAttr(BOOST_GET_CONST(int64_t, attr)); - } else if (attr_defs[i].type_index == phi::AttributeType::STRING) { - kernel_ctx->EmplaceBackAttr(BOOST_GET_CONST(std::string, attr)); - } else if (attr_defs[i].type_index == phi::AttributeType::DATA_TYPE) { - auto data_type = framework::TransToPhiDataType( - static_cast( - BOOST_GET_CONST(int, attr))); - kernel_ctx->EmplaceBackAttr(data_type); - } else if (attr_defs[i].type_index == phi::AttributeType::INT64S) { - if (AttrTypeID(attr) == framework::proto::AttrType::LONGS) { - kernel_ctx->EmplaceBackAttr( - BOOST_GET_CONST(std::vector, attr)); - } else if (AttrTypeID(attr) == framework::proto::AttrType::INTS) { - // Emplace Back Attr according to the type of Phi_Kernel args. - const auto& vector_int_attr = BOOST_GET_CONST(std::vector, attr); - const std::vector vector_int64_attr(vector_int_attr.begin(), - vector_int_attr.end()); - kernel_ctx->EmplaceBackAttr(vector_int64_attr); + } break; + default: { + PADDLE_ENFORCE_NOT_NULL( + attr_ptr, + platform::errors::NotFound("(%s) is not found in AttributeMap when " + "buildind dygraph KernelContext.", + attr_names[i])); + auto& attr = *attr_ptr; + switch (attr_defs[i].type_index) { + case phi::AttributeType::FLOAT32: + kernel_ctx->EmplaceBackAttr(BOOST_GET_CONST(float, attr)); + break; + case phi::AttributeType::INT32: + kernel_ctx->EmplaceBackAttr(BOOST_GET_CONST(int, attr)); + break; + case phi::AttributeType::BOOL: + kernel_ctx->EmplaceBackAttr(BOOST_GET_CONST(bool, attr)); + break; + case phi::AttributeType::INT64: + kernel_ctx->EmplaceBackAttr(BOOST_GET_CONST(int64_t, attr)); + break; + case phi::AttributeType::INT32S: + kernel_ctx->EmplaceBackAttr( + BOOST_GET_CONST(std::vector, attr)); + break; + case phi::AttributeType::DATA_TYPE: { + auto data_type = framework::TransToPhiDataType( + static_cast( + BOOST_GET_CONST(int, attr))); + kernel_ctx->EmplaceBackAttr(data_type); + } break; + case phi::AttributeType::STRING: + kernel_ctx->EmplaceBackAttr( + std::move(BOOST_GET_CONST(std::string, attr))); + break; + case phi::AttributeType::INT64S: { + switch (AttrTypeID(attr)) { + case framework::proto::AttrType::LONGS: + kernel_ctx->EmplaceBackAttr( + BOOST_GET_CONST(std::vector, attr)); + break; + case framework::proto::AttrType::INTS: { + const auto& vector_int_attr = + BOOST_GET_CONST(std::vector, attr); + const std::vector vector_int64_attr( + vector_int_attr.begin(), vector_int_attr.end()); + kernel_ctx->EmplaceBackAttr(vector_int64_attr); + } break; + default: + PADDLE_THROW(platform::errors::Unimplemented( + "Unsupported cast op attribute `%s` to vector " + "when " + "construct KernelContext.", + attr_names[i])); + } + } break; + case phi::AttributeType::FLOAT32S: + kernel_ctx->EmplaceBackAttr( + BOOST_GET_CONST(std::vector, attr)); + break; + case phi::AttributeType::STRINGS: + kernel_ctx->EmplaceBackAttr( + BOOST_GET_CONST(std::vector, attr)); + break; + default: + PADDLE_THROW(platform::errors::Unimplemented( + "Unsupported cast op attribute `%s` when construct " + "KernelContext in dygraph.", + attr_names[i])); } - } else if (attr_defs[i].type_index == phi::AttributeType::INT32S) { - kernel_ctx->EmplaceBackAttr(BOOST_GET_CONST(std::vector, attr)); - } else if (attr_defs[i].type_index == phi::AttributeType::STRINGS) { - kernel_ctx->EmplaceBackAttr( - BOOST_GET_CONST(std::vector, attr)); - } else if (attr_defs[i].type_index == phi::AttributeType::FLOAT32S) { - kernel_ctx->EmplaceBackAttr(BOOST_GET_CONST(std::vector, attr)); - } else { - PADDLE_THROW(platform::errors::Unimplemented( - "Unsupported cast op attribute `%s` when construct " - "KernelContext in dygraph.", - attr_names[i])); } } } + VLOG(6) << "BuildDygraphPhiKernelContext: Attributes parsing completed."; } template diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index cff14308c7fe9f63970c4c0ddb897798d100f26e..367129cd7267660e4d6c1009f13d395c3227794f 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -3011,7 +3011,7 @@ void UnStackInferMeta(const MetaTensor& x, } void OneHotRawInferMeta(const MetaTensor& x, - int32_t depth, + const Scalar& depth, DataType dtype, bool allow_out_of_range, MetaTensor* out) { @@ -3021,7 +3021,7 @@ void OneHotRawInferMeta(const MetaTensor& x, 1, phi::errors::InvalidArgument("Rank of Input(X) should be at least 1.")); auto out_dims_vec = phi::vectorize(x_dims); - out_dims_vec.push_back(depth); + out_dims_vec.push_back(depth.to()); auto out_dims = phi::make_ddim(out_dims_vec); out->set_dims(out_dims); out->share_lod(x); diff --git a/paddle/phi/infermeta/unary.h b/paddle/phi/infermeta/unary.h index eef750b852f06d26ce3bd1da84a27bf2c7d963f0..97fa932eed584d8941cd9795497a6d138c1b3616 100644 --- a/paddle/phi/infermeta/unary.h +++ b/paddle/phi/infermeta/unary.h @@ -431,7 +431,7 @@ void UnStackInferMeta(const MetaTensor& x, std::vector outs); void OneHotRawInferMeta(const MetaTensor& x, - int32_t depth, + const Scalar& depth, DataType dtype, bool allow_out_of_range, MetaTensor* out); diff --git a/paddle/phi/kernels/cpu/one_hot_kernel.cc b/paddle/phi/kernels/cpu/one_hot_kernel.cc index 04f7c6a1f606d97f93fb2c1eb7061bac5dc8f7a8..fc7979e41d938cdc381a2821d1bf33ff5706569d 100644 --- a/paddle/phi/kernels/cpu/one_hot_kernel.cc +++ b/paddle/phi/kernels/cpu/one_hot_kernel.cc @@ -64,18 +64,19 @@ struct OneHotV2OpFunctor { template void OneHotRawKernel(const Context& dev_ctx, const DenseTensor& x, - int32_t depth, + const Scalar& depth, DataType dtype, bool allow_out_of_range, DenseTensor* out) { + auto depth_v = depth.to(); auto out_dims = out->dims(); if (out_dims[out_dims.size() - 1] == -1) { - out_dims[out_dims.size() - 1] = depth; + out_dims[out_dims.size() - 1] = depth_v; out->Resize(out_dims); } phi::VisitDataType(dtype, - OneHotV2OpFunctor(&x, out, depth, dev_ctx)); + OneHotV2OpFunctor(&x, out, depth_v, dev_ctx)); } } // namespace phi diff --git a/paddle/phi/kernels/gpu/one_hot_kernel.cu b/paddle/phi/kernels/gpu/one_hot_kernel.cu index c5884884231a85dad7ceb37db414596bb3751d35..2ae9e9333ecb568e73115570fa08db77e162d85b 100644 --- a/paddle/phi/kernels/gpu/one_hot_kernel.cu +++ b/paddle/phi/kernels/gpu/one_hot_kernel.cu @@ -73,18 +73,19 @@ struct OneHotV2OpCUDAFunctor { template void OneHotRawKernel(const Context& dev_ctx, const DenseTensor& x, - int32_t depth, + const Scalar& depth, DataType dtype, bool allow_out_of_range, DenseTensor* out) { + auto depth_v = depth.to(); auto out_dims = out->dims(); if (out_dims[out_dims.size() - 1] == -1) { - out_dims[out_dims.size() - 1] = depth; + out_dims[out_dims.size() - 1] = depth_v; out->Resize(out_dims); } phi::VisitDataType( - dtype, OneHotV2OpCUDAFunctor(&x, out, depth, dev_ctx)); + dtype, OneHotV2OpCUDAFunctor(&x, out, depth_v, dev_ctx)); } } // namespace phi diff --git a/paddle/phi/kernels/one_hot_kernel.cc b/paddle/phi/kernels/one_hot_kernel.cc index 633f48cbb62ace9e3f7f21502bd61f8c305fb542..755e06752509a4d091ad95b9c0eaefe0998fa6d9 100644 --- a/paddle/phi/kernels/one_hot_kernel.cc +++ b/paddle/phi/kernels/one_hot_kernel.cc @@ -24,9 +24,8 @@ void OneHotKernel(const Context& dev_ctx, const DenseTensor& x, const Scalar& num_classes_s, DenseTensor* out) { - int num_classes = num_classes_s.to(); OneHotRawKernel( - dev_ctx, x, num_classes, phi::DataType::FLOAT32, false, out); + dev_ctx, x, num_classes_s, phi::DataType::FLOAT32, false, out); } } // namespace phi diff --git a/paddle/phi/kernels/one_hot_kernel.h b/paddle/phi/kernels/one_hot_kernel.h index 9f89609ea63365b0e7831201ca003d6c7320c5d7..79af88473b278ee93b5c3272768a722cc2935561 100644 --- a/paddle/phi/kernels/one_hot_kernel.h +++ b/paddle/phi/kernels/one_hot_kernel.h @@ -28,7 +28,7 @@ void OneHotKernel(const Context& dev_ctx, template void OneHotRawKernel(const Context& dev_ctx, const DenseTensor& x, - int32_t depth, + const Scalar& depth, DataType dtype, bool allow_out_of_range, DenseTensor* out);