Unverified commit 2e9fd5e4, authored by Feiyu Chan, committed by GitHub

Add basic functionalities to support Scalar & Scalars in op attr (#51984)

Add basic functionalities to support Scalar & Scalars in operator attributes.

1. extend the allowed types of operator attributes: add `paddle::experimental::Scalar` and the corresponding protobuf message types;
2. enhance Scalar: add formatting and equality comparison;
3. add code to handle Scalar & Scalars in the op maker, the conversion from paddle operators to phi kernels, opdesc construction and manipulation, the tensorrt converter, the tracer, operator construction, etc.;
4. bind `paddle::experimental::Scalar` to Python as `libpaddle.Scalar`;
5. add functionality to canonicalize an attribute map according to the OpProto (if the op the attribute map is used for has an OpProto);
6. add code to manipulate the Scalar proto message via the protobuf Python API;

Add unit tests.

1. add test cases for formatting, equality of Scalars, and WrapAsScalars;
2. add test cases for 'casting' between different variants of attributes;
3. add test cases for extracting Scalar & Scalars from attributes;
4. add test cases for CanonicalizeScalarAttrs (and fix a bug in the type index offset);
5. fix gmock's library filename on the Windows platform;
6. clean code: use canonicalize_attrs instead of inlining the function;
7. add test cases for libpaddle.Scalar in Python code;
8. add test cases for `make_scalar_proto`, which manipulates the proto message `Scalar` via the protobuf Python API.
Parent e91a7896
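For orientation, here is a minimal sketch (not part of the diff; the wrapper function is illustrative) of the round trip that the new helpers enable:

```cpp
#include "paddle/fluid/framework/attribute.h"

// Scalar -> proto::Scalar -> Scalar round trip, using the helpers added in
// this change; the equality comparison on Scalar is also new.
void RoundTripExample() {
  paddle::experimental::Scalar s(42.1);
  paddle::framework::proto::Scalar msg =
      paddle::framework::make_scalar_proto(s);
  paddle::experimental::Scalar back =
      paddle::framework::make_scalar_from_proto(msg);
  // back == s: a FLOAT64 Scalar is stored in the proto field `r`.
}
```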
......@@ -39,7 +39,7 @@ if(WIN32)
"${GTEST_INSTALL_DIR}/${CMAKE_INSTALL_LIBDIR}/gtest_main.lib"
CACHE FILEPATH "gtest main libraries." FORCE)
set(GMOCK_LIBRARIES
"${GTEST_INSTALL_DIR}/${CMAKE_INSTALL_LIBDIR}/libgmock.lib"
"${GTEST_INSTALL_DIR}/${CMAKE_INSTALL_LIBDIR}/gmock.lib"
CACHE FILEPATH "gmock libraries." FORCE)
string(REPLACE "/w " "" GTEST_CMAKE_C_FLAGS "${CMAKE_C_FLAGS}")
string(REPLACE "/w " "" GTEST_CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}")
......
......@@ -362,7 +362,10 @@ cc_test(
program_desc_test
SRCS program_desc_test.cc
DEPS proto_desc device_context)
cc_test(
op_desc_test
SRCS op_desc_test.cc
DEPS proto_desc)
cc_library(
op_version_proto
SRCS op_version_proto.cc
......@@ -538,7 +541,8 @@ cc_library(
glog
version
xxhash
dist_attr)
dist_attr
scalar)
cc_library(
op_registry
......
......@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/attribute.h"
#include "paddle/phi/common/complex.h"
#include "paddle/phi/common/scalar.h"
#include "paddle/utils/blank.h"
namespace paddle {
......@@ -24,6 +26,8 @@ paddle::any GetAttrValue(const Attribute& attr) {
return PADDLE_GET_CONST(int, attr);
case proto::AttrType::FLOAT:
return PADDLE_GET_CONST(float, attr);
case proto::AttrType::FLOAT64:
return PADDLE_GET_CONST(double, attr);
case proto::AttrType::STRING:
return PADDLE_GET_CONST(std::string, attr);
case proto::AttrType::INTS:
......@@ -50,6 +54,10 @@ paddle::any GetAttrValue(const Attribute& attr) {
return PADDLE_GET_CONST(BlockDesc*, attr);
case proto::AttrType::BLOCKS:
return PADDLE_GET_CONST(std::vector<BlockDesc*>, attr);
case proto::AttrType::SCALAR:
return PADDLE_GET_CONST(paddle::experimental::Scalar, attr);
case proto::AttrType::SCALARS:
return PADDLE_GET_CONST(std::vector<paddle::experimental::Scalar>, attr);
default:
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported Attribute value type `%s` for phi.",
......@@ -71,6 +79,9 @@ Attribute GetAttrValue(const proto::OpDesc::Attr& attr_desc) {
case proto::AttrType::STRING: {
return attr_desc.s();
}
case proto::AttrType::FLOAT64: {
return attr_desc.float64();
}
case proto::AttrType::BOOLEANS: {
std::vector<bool> val(attr_desc.bools_size());
for (int i = 0; i < attr_desc.bools_size(); ++i) {
......@@ -118,6 +129,18 @@ Attribute GetAttrValue(const proto::OpDesc::Attr& attr_desc) {
return val;
}
case proto::AttrType::SCALAR: {
return make_scalar_from_proto(attr_desc.scalar());
}
case proto::AttrType::SCALARS: {
std::vector<paddle::experimental::Scalar> val(attr_desc.scalars_size());
for (int i = 0; i < attr_desc.scalars_size(); ++i) {
val[i] = make_scalar_from_proto(attr_desc.scalars(i));
}
return val;
}
default:
PADDLE_THROW(platform::errors::Unavailable(
"Unsupported attribute type %d.", attr_desc.type()));
......@@ -147,5 +170,154 @@ Attribute GetAttrValue(const proto::VarDesc::Attr& attr_desc) {
return paddle::blank();
}
paddle::experimental::Scalar make_scalar_from_proto(const proto::Scalar& v) {
auto data_type = v.type();
switch (data_type) {
case proto::Scalar_Type_BOOLEAN:
return paddle::experimental::Scalar(v.b());
case proto::Scalar_Type_LONG:
return paddle::experimental::Scalar(v.i());
case proto::Scalar_Type_FLOAT64:
return paddle::experimental::Scalar(v.r());
case proto::Scalar_Type_COMPLEX128: {
phi::dtype::complex<double> value(v.c().r(), v.c().i());
return paddle::experimental::Scalar(value);
}
default:
PADDLE_THROW(
phi::errors::InvalidArgument("Expected scalar of type boolean, "
"integer, floating point or complex."));
break;
}
return paddle::experimental::Scalar();
}
proto::Scalar make_scalar_proto(const paddle::experimental::Scalar& v) {
proto::Scalar s;
auto data_type = v.dtype();
switch (data_type) {
case phi::DataType::BOOL:
s.set_b(v.to<bool>());
s.set_type(proto::Scalar_Type_BOOLEAN);
break;
case phi::DataType::INT8:
case phi::DataType::UINT8:
case phi::DataType::INT16:
case phi::DataType::UINT16:
case phi::DataType::INT32:
case phi::DataType::UINT32:
case phi::DataType::INT64:
case phi::DataType::UINT64:
s.set_i(v.to<int64_t>());
s.set_type(proto::Scalar_Type_LONG);
break;
case phi::DataType::FLOAT16:
case phi::DataType::BFLOAT16:
case phi::DataType::FLOAT32:
case phi::DataType::FLOAT64:
s.set_r(v.to<double>());
s.set_type(proto::Scalar_Type_FLOAT64);
break;
case phi::DataType::COMPLEX64:
case phi::DataType::COMPLEX128: {
auto value = v.to<phi::dtype::complex<double>>();
auto* complex = s.mutable_c();
complex->set_r(value.real);
complex->set_i(value.imag);
s.set_type(proto::Scalar_Type_COMPLEX128);
break;
}
case phi::DataType::UNDEFINED:
case phi::DataType::PSTRING:
case phi::DataType::NUM_DATA_TYPES:
PADDLE_THROW(
phi::errors::InvalidArgument("Expected scalar of type boolean, "
"integer, floating point or complex."));
break;
default:
PADDLE_THROW(
phi::errors::InvalidArgument("Expected scalar of type boolean, "
"integer, floating point or complex."));
break;
}
return s;
}
paddle::experimental::Scalar make_scalar_from_attribute(const Attribute& v) {
auto attr_type = static_cast<proto::AttrType>(v.index() - 1);
switch (attr_type) {
case proto::AttrType::SCALAR:
return paddle::experimental::Scalar(
PADDLE_GET_CONST(paddle::experimental::Scalar, v));
case proto::AttrType::BOOLEAN:
return paddle::experimental::Scalar(PADDLE_GET_CONST(bool, v));
case proto::AttrType::INT:
return paddle::experimental::Scalar(PADDLE_GET_CONST(int, v));
case proto::AttrType::LONG:
return paddle::experimental::Scalar(PADDLE_GET_CONST(int64_t, v));
case proto::AttrType::FLOAT:
return paddle::experimental::Scalar(PADDLE_GET_CONST(float, v));
case proto::AttrType::FLOAT64:
return paddle::experimental::Scalar(PADDLE_GET_CONST(double, v));
default:
PADDLE_THROW(phi::errors::InvalidArgument(
"Unable to construct Scalar from given Attribute of type %s",
attr_type));
}
}
std::vector<paddle::experimental::Scalar> make_scalars_from_attribute(
const Attribute& v) {
auto attr_type = static_cast<proto::AttrType>(v.index() - 1);
switch (attr_type) {
case proto::AttrType::SCALARS:
return PADDLE_GET_CONST(std::vector<paddle::experimental::Scalar>, v);
case proto::AttrType::BOOLEANS:
return experimental::WrapAsScalars(
PADDLE_GET_CONST(std::vector<bool>, v));
case proto::AttrType::INTS:
return experimental::WrapAsScalars(PADDLE_GET_CONST(std::vector<int>, v));
case proto::AttrType::LONGS:
return experimental::WrapAsScalars(
PADDLE_GET_CONST(std::vector<int64_t>, v));
case proto::AttrType::FLOATS:
return experimental::WrapAsScalars(
PADDLE_GET_CONST(std::vector<float>, v));
case proto::AttrType::FLOAT64S:
return experimental::WrapAsScalars(
PADDLE_GET_CONST(std::vector<double>, v));
default:
PADDLE_THROW(phi::errors::InvalidArgument(
"Unable to construct Scalars from given Attribute of type %s",
attr_type));
}
}
void CanonicalizeScalarAttrs(const proto::OpProto& op_proto,
AttributeMap* attrs) {
PADDLE_ENFORCE_NOT_NULL(
attrs, platform::errors::InvalidArgument("attrs can not be nullptr"));
for (auto& attr : op_proto.attrs()) {
proto::AttrType attr_type = attr.type();
const std::string& attr_name = attr.name();
auto it = attrs->find(attr_name);
if (it == attrs->end()) {
continue;
}
proto::AttrType actual_attr_type = AttrTypeID(it->second);
if (actual_attr_type == attr_type) {
continue;
}
if (actual_attr_type == proto::AttrType::VAR ||
actual_attr_type == proto::AttrType::VARS) {
continue;  // VAR & VARS are not proper attributes
}
if (attr_type == proto::AttrType::SCALAR) {
it->second = Attribute(make_scalar_from_attribute(it->second));
} else if (attr_type == proto::AttrType::SCALARS) {
it->second = Attribute(make_scalars_from_attribute(it->second));
}
}
}
} // namespace framework
} // namespace paddle
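To make the canonicalization above concrete, a hedged sketch (wrapper function illustrative) of promoting a plain `vector<double>` attribute, as `CanonicalizeScalarAttrs` does when the OpProto declares a SCALARS attribute:

```cpp
#include "paddle/fluid/framework/attribute.h"

void PromoteToScalarsExample() {
  paddle::framework::Attribute attr = std::vector<double>{1.0, 2.0};
  std::vector<paddle::experimental::Scalar> scalars =
      paddle::framework::make_scalars_from_attribute(attr);
  // scalars[0] == paddle::experimental::Scalar(1.0): each double is wrapped
  // via WrapAsScalars.
}
```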
......@@ -27,6 +27,7 @@ limitations under the License. */
#include "paddle/fluid/framework/type_defs.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/errors.h"
#include "paddle/phi/common/scalar.h"
#include "paddle/utils/any.h"
#include "paddle/utils/variant.h"
......@@ -244,6 +245,29 @@ struct ExtractAttribute<std::vector<double>> {
const std::string& attr_name_;
};
template <>
struct ExtractAttribute<paddle::experimental::Scalar> {
explicit ExtractAttribute(const std::string& attr_name)
: attr_name_(attr_name) {}
paddle::experimental::Scalar* operator()(Attribute& attr) const {
paddle::experimental::Scalar* attr_value = nullptr;
try {
attr_value = &paddle::get<paddle::experimental::Scalar>(attr);
} catch (paddle::bad_variant_access const& bad_get) {
PADDLE_THROW(platform::errors::InvalidArgument(
"Cannot get attribute (%s) by type Scalar, its type is %s, index is "
"%d",
attr_name_,
paddle::platform::demangle(attr.type().name()),
attr.index()));
}
return attr_value;
}
const std::string& attr_name_;
};
template <typename T>
inline proto::AttrType AttrTypeID() {
Attribute tmp = T();
......@@ -325,5 +349,12 @@ class AttrReader {
const AttributeMap* default_attrs_;
};
paddle::experimental::Scalar make_scalar_from_proto(const proto::Scalar& v);
proto::Scalar make_scalar_proto(const paddle::experimental::Scalar& v);
paddle::experimental::Scalar make_scalar_from_attribute(const Attribute& v);
std::vector<paddle::experimental::Scalar> make_scalars_from_attribute(
const Attribute& v);
void CanonicalizeScalarAttrs(const proto::OpProto& op_proto,
AttributeMap* attrs);
} // namespace framework
} // namespace paddle
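A sketch of how the new `ExtractAttribute` specialization is used; the attribute name `alpha` is made up here and only feeds the error message on a type mismatch:

```cpp
#include "paddle/fluid/framework/attribute.h"

void ExtractScalarExample() {
  paddle::framework::Attribute attr = paddle::experimental::Scalar(42.1);
  std::string name = "alpha";  // hypothetical attribute name
  paddle::framework::ExtractAttribute<paddle::experimental::Scalar> extract(
      name);
  paddle::experimental::Scalar* value = extract(attr);
  // *value is Scalar(42.1); a bad_variant_access inside is rethrown as a
  // PADDLE_THROW carrying the attribute name.
}
```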
......@@ -20,6 +20,7 @@
#include "gtest/gtest.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/var_desc.h"
#include "paddle/phi/common/scalar.h"
#include "paddle/utils/any.h"
TEST(Attribute, GetAttrValueToAny) {
......@@ -77,17 +78,18 @@ TEST(Attribute, GetAttrValueToAny) {
paddle::framework::Attribute var_attr(&var_desc);
auto rlt_var_attr = paddle::framework::GetAttrValue(var_attr);
auto var_desc_ptr =
paddle::any_cast<paddle::framework::VarDesc*>(rlt_var_attr);
paddle::any_cast<paddle::framework::VarDesc *>(rlt_var_attr);
EXPECT_NE(var_desc_ptr, nullptr);
EXPECT_EQ(var_desc_ptr->Name(), var_desc.Name());
paddle::framework::VarDesc var2_desc("prob");
std::vector<paddle::framework::VarDesc*> vars_desc{&var_desc, &var2_desc};
std::vector<paddle::framework::VarDesc *> vars_desc{&var_desc, &var2_desc};
paddle::framework::Attribute vars_attr(vars_desc);
auto rlt_vars_attr = paddle::framework::GetAttrValue(vars_attr);
auto rlt_vars_desc =
paddle::any_cast<std::vector<paddle::framework::VarDesc*>>(rlt_vars_attr);
paddle::any_cast<std::vector<paddle::framework::VarDesc *>>(
rlt_vars_attr);
EXPECT_EQ(rlt_vars_desc.size(), vars_desc.size());
EXPECT_EQ(rlt_vars_desc[0]->Name(), vars_desc[0]->Name());
EXPECT_EQ(rlt_vars_desc[1]->Name(), vars_desc[1]->Name());
......@@ -98,15 +100,15 @@ TEST(Attribute, GetAttrValueToAny) {
paddle::framework::Attribute x_block_desc(&block_desc);
auto rlt_block_desc = paddle::framework::GetAttrValue(x_block_desc);
auto block_desc_ptr =
paddle::any_cast<paddle::framework::BlockDesc*>(rlt_block_desc);
paddle::any_cast<paddle::framework::BlockDesc *>(rlt_block_desc);
EXPECT_NE(block_desc_ptr, nullptr);
std::vector<paddle::framework::BlockDesc*> vec_block_desc_var;
std::vector<paddle::framework::BlockDesc *> vec_block_desc_var;
vec_block_desc_var.emplace_back(&block_desc);
paddle::framework::Attribute x_vec_block_desc(vec_block_desc_var);
auto rlt_vec_block_desc = paddle::framework::GetAttrValue(x_vec_block_desc);
auto vec_block_desc =
paddle::any_cast<std::vector<paddle::framework::BlockDesc*>>(
paddle::any_cast<std::vector<paddle::framework::BlockDesc *>>(
rlt_vec_block_desc);
EXPECT_EQ(vec_block_desc.size(), 1UL);
EXPECT_NE(vec_block_desc[0], nullptr);
......@@ -131,4 +133,235 @@ TEST(Attribute, GetAttrValueToAny) {
EXPECT_EQ(vec_double.size(), 2UL);
EXPECT_NEAR(vec_double[0], 3.14, 1e-6);
EXPECT_NEAR(vec_double[1], 3.14, 1e-6);
double x_double_val = 42.1;
paddle::framework::Attribute x_double(x_double_val);
ASSERT_EQ(AttrTypeID(x_double), paddle::framework::proto::FLOAT64);
EXPECT_NEAR(
paddle::any_cast<double>(paddle::framework::GetAttrValue(x_double)),
42.1,
1e-6);
paddle::framework::Attribute x_scalar = paddle::experimental::Scalar(42.1);
ASSERT_EQ(AttrTypeID(x_scalar), paddle::framework::proto::SCALAR);
EXPECT_EQ(paddle::any_cast<paddle::experimental::Scalar>(
paddle::framework::GetAttrValue(x_scalar)),
paddle::experimental::Scalar(42.1));
std::vector<paddle::experimental::Scalar> scalars =
paddle::experimental::WrapAsScalars(std::vector<int64_t>{1, 2, 3});
paddle::framework::Attribute x_scalars(scalars);
ASSERT_EQ(AttrTypeID(x_scalars), paddle::framework::proto::SCALARS);
auto x_extracted =
paddle::any_cast<std::vector<paddle::experimental::Scalar>>(
paddle::framework::GetAttrValue(x_scalars));
EXPECT_EQ(x_extracted.size(), 3UL);
EXPECT_EQ(x_extracted.at(0), scalars.at(0));
EXPECT_EQ(x_extracted.at(1), scalars.at(1));
EXPECT_EQ(x_extracted.at(2), scalars.at(2));
}
TEST(Attribute, ProtoAttrToAttribute_double) {
paddle::framework::proto::OpDesc::Attr proto_attr_double;
proto_attr_double.set_name("anon");
proto_attr_double.set_type(paddle::framework::proto::FLOAT64);
proto_attr_double.set_float64(42.1);
paddle::framework::Attribute attr_double =
paddle::framework::GetAttrValue(proto_attr_double);
ASSERT_EQ(AttrTypeID(attr_double), paddle::framework::proto::FLOAT64);
}
TEST(Attribute, ProtoAttrToAttribute_scalar) {
paddle::framework::proto::OpDesc::Attr proto_attr_scalar;
proto_attr_scalar.set_name("anon");
proto_attr_scalar.set_type(paddle::framework::proto::SCALAR);
auto s_bool = paddle::experimental::Scalar(static_cast<bool>(true));
auto s_int8 = paddle::experimental::Scalar(static_cast<int8_t>(42.1));
auto s_int16 = paddle::experimental::Scalar(static_cast<int16_t>(42.1));
auto s_int32 = paddle::experimental::Scalar(static_cast<int32_t>(42.1));
auto s_int64 = paddle::experimental::Scalar(static_cast<int64_t>(42.1));
auto s_uint8 = paddle::experimental::Scalar(static_cast<uint8_t>(42.1));
auto s_uint16 = paddle::experimental::Scalar(static_cast<uint16_t>(42.1));
auto s_uint32 = paddle::experimental::Scalar(static_cast<uint32_t>(42.1));
auto s_uint64 = paddle::experimental::Scalar(static_cast<uint64_t>(42.1));
auto s_float16 =
paddle::experimental::Scalar(static_cast<phi::float16>(42.1));
auto s_bfloat16 =
paddle::experimental::Scalar(static_cast<phi::bfloat16>(42.1));
auto s_float = paddle::experimental::Scalar(static_cast<float>(42.1));
auto s_double = paddle::experimental::Scalar(static_cast<double>(42.1));
auto s_cfloat = paddle::experimental::Scalar(std::complex<float>(42.1, 42.1));
auto s_cdouble =
paddle::experimental::Scalar(std::complex<double>(42.1, 42.1));
auto proto_scalar_bool = new paddle::framework::proto::Scalar;
*proto_scalar_bool = paddle::framework::make_scalar_proto(s_bool);
proto_attr_scalar.set_allocated_scalar(proto_scalar_bool);
ASSERT_EQ(AttrTypeID(paddle::framework::GetAttrValue(proto_attr_scalar)),
paddle::framework::proto::SCALAR);
auto proto_scalar_int8 = new paddle::framework::proto::Scalar;
*proto_scalar_int8 = paddle::framework::make_scalar_proto(s_int8);
proto_attr_scalar.set_allocated_scalar(proto_scalar_int8);
ASSERT_EQ(AttrTypeID(paddle::framework::GetAttrValue(proto_attr_scalar)),
paddle::framework::proto::SCALAR);
auto proto_scalar_int16 = new paddle::framework::proto::Scalar;
*proto_scalar_int16 = paddle::framework::make_scalar_proto(s_int16);
proto_attr_scalar.set_allocated_scalar(proto_scalar_int16);
ASSERT_EQ(AttrTypeID(paddle::framework::GetAttrValue(proto_attr_scalar)),
paddle::framework::proto::SCALAR);
auto proto_scalar_int32 = new paddle::framework::proto::Scalar;
*proto_scalar_int32 = paddle::framework::make_scalar_proto(s_int32);
proto_attr_scalar.set_allocated_scalar(proto_scalar_int32);
ASSERT_EQ(AttrTypeID(paddle::framework::GetAttrValue(proto_attr_scalar)),
paddle::framework::proto::SCALAR);
auto proto_scalar_int64 = new paddle::framework::proto::Scalar;
*proto_scalar_int64 = paddle::framework::make_scalar_proto(s_int64);
proto_attr_scalar.set_allocated_scalar(proto_scalar_int64);
ASSERT_EQ(AttrTypeID(paddle::framework::GetAttrValue(proto_attr_scalar)),
paddle::framework::proto::SCALAR);
auto proto_scalar_uint8 = new paddle::framework::proto::Scalar;
*proto_scalar_uint8 = paddle::framework::make_scalar_proto(s_uint8);
proto_attr_scalar.set_allocated_scalar(proto_scalar_uint8);
ASSERT_EQ(AttrTypeID(paddle::framework::GetAttrValue(proto_attr_scalar)),
paddle::framework::proto::SCALAR);
auto proto_scalar_uint16 = new paddle::framework::proto::Scalar;
*proto_scalar_uint16 = paddle::framework::make_scalar_proto(s_uint16);
proto_attr_scalar.set_allocated_scalar(proto_scalar_uint16);
ASSERT_EQ(AttrTypeID(paddle::framework::GetAttrValue(proto_attr_scalar)),
paddle::framework::proto::SCALAR);
auto proto_scalar_uint32 = new paddle::framework::proto::Scalar;
*proto_scalar_uint32 = paddle::framework::make_scalar_proto(s_uint32);
proto_attr_scalar.set_allocated_scalar(proto_scalar_uint32);
ASSERT_EQ(AttrTypeID(paddle::framework::GetAttrValue(proto_attr_scalar)),
paddle::framework::proto::SCALAR);
auto proto_scalar_uint64 = new paddle::framework::proto::Scalar;
*proto_scalar_uint64 = paddle::framework::make_scalar_proto(s_uint64);
proto_attr_scalar.set_allocated_scalar(proto_scalar_uint64);
ASSERT_EQ(AttrTypeID(paddle::framework::GetAttrValue(proto_attr_scalar)),
paddle::framework::proto::SCALAR);
auto proto_scalar_float16 = new paddle::framework::proto::Scalar;
*proto_scalar_float16 = paddle::framework::make_scalar_proto(s_float16);
proto_attr_scalar.set_allocated_scalar(proto_scalar_float16);
ASSERT_EQ(AttrTypeID(paddle::framework::GetAttrValue(proto_attr_scalar)),
paddle::framework::proto::SCALAR);
auto proto_scalar_bfloat16 = new paddle::framework::proto::Scalar;
*proto_scalar_bfloat16 = paddle::framework::make_scalar_proto(s_bfloat16);
proto_attr_scalar.set_allocated_scalar(proto_scalar_bfloat16);
ASSERT_EQ(AttrTypeID(paddle::framework::GetAttrValue(proto_attr_scalar)),
paddle::framework::proto::SCALAR);
auto proto_scalar_float = new paddle::framework::proto::Scalar;
*proto_scalar_float = paddle::framework::make_scalar_proto(s_float);
proto_attr_scalar.set_allocated_scalar(proto_scalar_float);
ASSERT_EQ(AttrTypeID(paddle::framework::GetAttrValue(proto_attr_scalar)),
paddle::framework::proto::SCALAR);
auto proto_scalar_double = new paddle::framework::proto::Scalar;
*proto_scalar_double = paddle::framework::make_scalar_proto(s_double);
proto_attr_scalar.set_allocated_scalar(proto_scalar_double);
ASSERT_EQ(AttrTypeID(paddle::framework::GetAttrValue(proto_attr_scalar)),
paddle::framework::proto::SCALAR);
auto proto_scalar_cfloat = new paddle::framework::proto::Scalar;
*proto_scalar_cfloat = paddle::framework::make_scalar_proto(s_cfloat);
proto_attr_scalar.set_allocated_scalar(proto_scalar_cfloat);
ASSERT_EQ(AttrTypeID(paddle::framework::GetAttrValue(proto_attr_scalar)),
paddle::framework::proto::SCALAR);
auto proto_scalar_cdouble = new paddle::framework::proto::Scalar;
*proto_scalar_cdouble = paddle::framework::make_scalar_proto(s_cdouble);
proto_attr_scalar.set_allocated_scalar(proto_scalar_cdouble);
ASSERT_EQ(AttrTypeID(paddle::framework::GetAttrValue(proto_attr_scalar)),
paddle::framework::proto::SCALAR);
}
TEST(Attribute, ProtoAttrToAttribute_scalars) {
paddle::framework::proto::OpDesc::Attr proto_attr_scalars;
proto_attr_scalars.set_name("anon");
proto_attr_scalars.set_type(paddle::framework::proto::SCALARS);
std::vector<paddle::experimental::Scalar> scalars;
for (int i = 0; i < 10; i++) {
scalars.push_back(paddle::experimental::Scalar(i));
}
std::vector<paddle::framework::proto::Scalar> proto_scalars;
proto_scalars.reserve(scalars.size());
for (const auto &item : scalars) {
proto_scalars.emplace_back(paddle::framework::make_scalar_proto(item));
}
paddle::framework::VectorToRepeated(proto_scalars,
proto_attr_scalars.mutable_scalars());
ASSERT_EQ(AttrTypeID(paddle::framework::GetAttrValue(proto_attr_scalars)),
paddle::framework::proto::SCALARS);
}
TEST(Attribute, make_scalar_from_attribute) {
using paddle::framework::make_scalar_from_attribute;
auto s_bool = true;
auto s_int32 = static_cast<int32_t>(42.1);
auto s_int64 = static_cast<int64_t>(42.1);
auto s_float = static_cast<float>(42.1);
auto s_double = static_cast<double>(42.1);
auto s_scalar = paddle::experimental::Scalar(42.1);
ASSERT_EQ(make_scalar_from_attribute(paddle::framework::Attribute(s_bool)),
paddle::experimental::Scalar(s_bool));
ASSERT_EQ(make_scalar_from_attribute(paddle::framework::Attribute(s_int32)),
paddle::experimental::Scalar(s_int32));
ASSERT_EQ(make_scalar_from_attribute(paddle::framework::Attribute(s_int64)),
paddle::experimental::Scalar(s_int64));
ASSERT_EQ(make_scalar_from_attribute(paddle::framework::Attribute(s_float)),
paddle::experimental::Scalar(s_float));
ASSERT_EQ(make_scalar_from_attribute(paddle::framework::Attribute(s_double)),
paddle::experimental::Scalar(s_double));
ASSERT_EQ(make_scalar_from_attribute(paddle::framework::Attribute(s_scalar)),
s_scalar);
}
TEST(Attribute, make_scalars_from_attribute) {
using paddle::framework::make_scalars_from_attribute;
std::vector<bool> v_bool(4, true);
std::vector<int> v_int(4, 42);
std::vector<int64_t> v_int64(4, 42);
std::vector<float> v_float(4, 42.1);
std::vector<double> v_double(4, 42.1);
std::vector<paddle::experimental::Scalar> v_scalar(
4, paddle::experimental::Scalar(std::complex<float>(42.1, 42.1)));
ASSERT_EQ(
make_scalars_from_attribute(paddle::framework::Attribute(v_bool))[0],
paddle::experimental::Scalar(v_bool[0]));
ASSERT_EQ(make_scalars_from_attribute(paddle::framework::Attribute(v_int))[0],
paddle::experimental::Scalar(v_int[0]));
ASSERT_EQ(
make_scalars_from_attribute(paddle::framework::Attribute(v_int64))[0],
paddle::experimental::Scalar(v_int64[0]));
ASSERT_EQ(
make_scalars_from_attribute(paddle::framework::Attribute(v_float))[0],
paddle::experimental::Scalar(v_float[0]));
ASSERT_EQ(
make_scalars_from_attribute(paddle::framework::Attribute(v_double))[0],
paddle::experimental::Scalar(v_double[0]));
ASSERT_EQ(
make_scalars_from_attribute(paddle::framework::Attribute(v_scalar))[0],
v_scalar[0]);
}
......@@ -39,8 +39,31 @@ enum AttrType {
VAR = 13;
VARS = 14;
FLOAT64 = 15;
SCALAR = 16;
SCALARS = 17;
}
message Complex {
required double r = 1;
required double i = 2;
};
message Scalar {
enum Type {
BOOLEAN = 1;
LONG = 2;
FLOAT64 = 3;
COMPLEX128 = 4;
}
required Type type = 1;
optional bool b = 2;
optional int64 i = 3;
optional double r = 4;
optional Complex c = 5;
};
// OpDesc describes an instance of a C++ framework::OperatorBase
// derived class type.
message OpDesc {
......@@ -64,6 +87,8 @@ message OpDesc {
optional string var_name = 17;
repeated string vars_name = 18;
optional double float64 = 19;
optional Scalar scalar = 20;
repeated Scalar scalars = 21;
};
message Var {
......
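For reference, a minimal sketch of filling the new message by hand through the generated C++ protobuf API, matching what `make_scalar_proto` emits for a complex-valued Scalar:

```cpp
#include "paddle/fluid/framework/framework.pb.h"

void BuildScalarProtoExample() {
  paddle::framework::proto::Scalar msg;
  msg.set_type(paddle::framework::proto::Scalar_Type_COMPLEX128);
  auto* c = msg.mutable_c();  // nested Complex message
  c->set_r(42.1);             // real part
  c->set_i(42.1);             // imaginary part
}
```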
......@@ -581,6 +581,7 @@ CompatInferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
case phi::AttributeType::SCALAR:
if (attr_ptr && !is_attr_var) {
auto& attr = *attr_ptr;
VLOG(6) << "type: " << AttrTypeID(attr);
switch (AttrTypeID(attr)) {
case framework::proto::AttrType::FLOAT:
infer_meta_context.EmplaceBackAttr(
......@@ -606,6 +607,10 @@ CompatInferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
infer_meta_context.EmplaceBackAttr(
phi::Scalar(PADDLE_GET_CONST(bool, attr)));
break;
case framework::proto::AttrType::SCALAR:
infer_meta_context.EmplaceBackAttr(phi::Scalar(
PADDLE_GET_CONST(paddle::experimental::Scalar, attr)));
break;
default:
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported cast op attribute `%s` to Scalar when construct "
......@@ -746,6 +751,12 @@ CompatInferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
}
infer_meta_context.EmplaceBackAttr(std::move(scalar_list));
} break;
case proto::AttrType::SCALARS: {
const auto& vec = PADDLE_GET_CONST(
std::vector<paddle::experimental::Scalar>, attr);
std::vector<phi::Scalar> scalar_list{vec.begin(), vec.end()};
infer_meta_context.EmplaceBackAttr(std::move(scalar_list));
} break;
default:
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported cast op attribute `%s` to vector<Scalar> when "
......
......@@ -24,6 +24,7 @@ limitations under the License. */
#include "paddle/fluid/framework/shape_inference.h"
#include "paddle/fluid/framework/var_type_inference.h"
#include "paddle/fluid/operators/ops_extra_info.h"
#include "paddle/phi/common/complex.h"
#include "paddle/utils/blank.h"
namespace paddle {
......@@ -688,10 +689,12 @@ void OpDesc::SetAttr(const std::string &name, const Attribute &v) {
if (attr_type == proto::AttrType::INTS &&
PADDLE_GET_CONST(std::vector<int>, v).size() == 0u) {
// Find current attr via attr name and set the correct attribute value
auto attr_type =
is_runtime_attr
? static_cast<proto::AttrType>(extra_attr_iter->second.index() - 1)
: GetProtoAttr(name).type();
if (is_runtime_attr) {
attr_type =
static_cast<proto::AttrType>(extra_attr_iter->second.index() - 1);
} else if (HasProtoAttr(name)) {
attr_type = GetProtoAttr(name).type();
}
switch (attr_type) {
case proto::AttrType::BOOLEANS: {
VLOG(11) << "SetAttr: " << Type() << ", " << name
......@@ -761,6 +764,9 @@ void OpDesc::SetAttr(const std::string &name, const Attribute &v) {
}
attrs_ptr->operator[](name) = v;
VLOG(10) << "op_type: " << Type() << ", attr name: " << name
<< " , type index: "
<< TransToPhiDataType(this->attrs_[name].index());
need_update_ = true;
}
......@@ -927,6 +933,11 @@ struct SetAttrDescVisitor {
void operator()(float v) const { attr_->set_f(v); }
void operator()(double v) const { attr_->set_float64(v); }
void operator()(const std::string &v) const { attr_->set_s(v); }
void operator()(const paddle::experimental::Scalar &v) const {
auto *s = new proto::Scalar;
*s = make_scalar_proto(v);
attr_->set_allocated_scalar(s);
}
// Please refer to https://github.com/PaddlePaddle/Paddle/issues/7162
template <class T,
......@@ -980,6 +991,15 @@ struct SetAttrDescVisitor {
VectorToRepeated(v, attr_->mutable_float64s());
}
void operator()(const std::vector<paddle::experimental::Scalar> &v) const {
std::vector<proto::Scalar> scalars;
scalars.reserve(v.size());
for (const auto &item : v) {
scalars.emplace_back(make_scalar_proto(item));
}
VectorToRepeated(scalars, attr_->mutable_scalars());
}
void operator()(paddle::blank) const {
PADDLE_THROW(platform::errors::Unavailable(
"Unsupported calling method of SetAttrDescVisitor object for "
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/op_desc.h"
#include <complex>
#include "gtest/gtest.h"
#include "paddle/phi/common/scalar.h"
TEST(OpDesc, SetScalarAttr) {
paddle::framework::OpDesc opdesc;
paddle::experimental::Scalar scalar(std::complex<double>(42.1, 42.1));
opdesc.SetPlainAttr("scalar", scalar);
ASSERT_EQ(opdesc.GetAttrType("scalar"), paddle::framework::proto::SCALAR);
}
TEST(OpDesc, SetScalarsAttr) {
paddle::framework::OpDesc opdesc;
paddle::experimental::Scalar scalar(std::complex<double>(42.1, 42.1));
std::vector<paddle::experimental::Scalar> scalars;
for (int i = 0; i < 4; i++) {
scalars.push_back(paddle::experimental::Scalar(i));
}
opdesc.SetPlainAttr("scalars", scalars);
ASSERT_EQ(opdesc.GetAttrType("scalars"), paddle::framework::proto::SCALARS);
opdesc.Flush();
}
......@@ -17,6 +17,7 @@ limitations under the License. */
#include "gtest/gtest-message.h"
#include "gtest/gtest-test-part.h"
#include "gtest/gtest.h"
#include "paddle/phi/common/scalar.h"
class TestAttrProtoMaker : public paddle::framework::OpProtoAndCheckerMaker {
public:
......@@ -49,3 +50,34 @@ TEST(ProtoMaker, DuplicatedInOut) {
ASSERT_THROW(proto_maker(&op_proto, &op_checker),
paddle::platform::EnforceNotMet);
}
class OpProtoMakerWithScalar
: public paddle::framework::OpProtoAndCheckerMaker {
public:
void Make() {
AddAttr<paddle::experimental::Scalar>("generic_scalar",
"generic_scalar of test op");
AddAttr<std::vector<paddle::experimental::Scalar>>(
"generic_vector", "generic_vector of test op");
}
};
TEST(OpProto, CanonicalizeScalarAttrs) {
paddle::framework::proto::OpProto op_proto;
paddle::framework::OpAttrChecker op_checker;
OpProtoMakerWithScalar proto_maker;
proto_maker(&op_proto, &op_checker);
paddle::framework::AttributeMap amap;
amap.insert(
std::make_pair("generic_scalar", paddle::framework::Attribute(42.1)));
amap.insert(std::make_pair(
"generic_vector",
paddle::framework::Attribute(std::vector<double>{42.1, 42.2, 42.3})));
paddle::framework::CanonicalizeScalarAttrs(op_proto, &amap);
ASSERT_EQ(AttrTypeID(amap["generic_scalar"]),
paddle::framework::proto::SCALAR);
ASSERT_EQ(AttrTypeID(amap["generic_vector"]),
paddle::framework::proto::SCALARS);
}
......@@ -38,6 +38,9 @@ std::unique_ptr<OperatorBase> OpRegistry::CreateOp(
}
}
auto& info = OpInfoMap::Instance().Get(type);
if (info.proto_) {
CanonicalizeScalarAttrs(*info.proto_, &standard_attrs);
}
if (attr_check) {
if (info.Checker() != nullptr) {
info.Checker()->Check(&standard_attrs);
......@@ -67,6 +70,9 @@ std::unique_ptr<OperatorBase> OpRegistry::CreateOp(
auto& info = OpInfoMap::Instance().Get(type);
if (attr_check && info.Checker() != nullptr) {
auto tmp_attrs = attrs;
if (info.proto_) {
CanonicalizeScalarAttrs(*info.proto_, &tmp_attrs);
}
info.Checker()->Check(&tmp_attrs);
op_base = std::unique_ptr<OperatorBase>(
info.Creator()(type, inputs, outputs, tmp_attrs));
......
......@@ -29,6 +29,7 @@ class OpVersion {
public:
explicit OpVersion(proto::OpVersion* desc) : desc_{desc} {}
void SetVersionID(uint32_t version) { desc_->set_version(version); }
uint32_t get() const { return desc_->version(); }
private:
proto::OpVersion* desc_;
......
......@@ -22,6 +22,7 @@ limitations under the License. */
#include "paddle/fluid/framework/op_version_proto.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/phi/common/scalar.h"
#include "paddle/phi/core/macros.h"
#include "paddle/utils/none.h"
......@@ -33,19 +34,23 @@ namespace pb {
class OpVersionMap;
} // namespace pb
using OpAttrVariantT =
paddle::variant<bool, /* AttrType::BOOL */
float, /* AttrType::FLOAT */
int32_t, /* AttrType::INT */
int64_t, /* AttrType::LONG*/
std::string, /* AttrType::STRING */
std::vector<bool>, /* AttrType::BOOLS */
std::vector<float>, /* AttrType::FLOATS */
std::vector<int32_t>, /* AttrType::INTS */
std::vector<int64_t>, /* AttrType::LONGS */
std::vector<std::string>, /* AttrType::STRINGS */
paddle::none_t /* None */
>;
using OpAttrVariantT = paddle::variant<
bool, /* AttrType::BOOL */
float, /* AttrType::FLOAT */
double, /* AttrType::FLOAT64 */
int32_t, /* AttrType::INT */
int64_t, /* AttrType::LONG*/
std::string, /* AttrType::STRING */
paddle::experimental::Scalar, /* AttrType::SCALAR */
std::vector<bool>, /* AttrType::BOOLS */
std::vector<float>, /* AttrType::FLOATS */
std::vector<double>, /* AttrType::FLOAT64S */
std::vector<int32_t>, /* AttrType::INTS */
std::vector<int64_t>, /* AttrType::LONGS */
std::vector<std::string>, /* AttrType::STRINGS */
std::vector<paddle::experimental::Scalar>, /* AttrType::SCALARS*/
paddle::none_t /* None */
>;
struct OpUpdateInfo {
virtual ~OpUpdateInfo() = default;
......
......@@ -966,6 +966,11 @@ OperatorBase::OperatorBase(const std::string& type,
GenerateTemporaryNames();
CheckAllInputOutputSet();
}
// canonicalize attrs
if (info_ && info_->proto_) {
CanonicalizeScalarAttrs(*info_->proto_, &attrs_);
}
// In OperatorBase level, all attributes with VarDesc type will be considered
// as Input.
for (auto& attr : FilterAttrVar(attrs)) {
......@@ -3251,6 +3256,11 @@ void OperatorWithKernel::BuildPhiKernelContext(
phi_kernel_context->EmplaceBackAttr(std::move(
phi::Scalar(PADDLE_GET_CONST(bool, attr_iter->second))));
break;
case proto::AttrType::SCALAR:
phi_kernel_context->EmplaceBackAttr(
std::move(phi::Scalar(PADDLE_GET_CONST(
paddle::experimental::Scalar, attr_iter->second))));
break;
default:
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported cast op attribute `%s` to Scalar when construct "
......@@ -3360,6 +3370,12 @@ void OperatorWithKernel::BuildPhiKernelContext(
}
phi_kernel_context->EmplaceBackAttr(std::move(scalar_list));
} break;
case proto::AttrType::SCALARS: {
const auto& vec = PADDLE_GET_CONST(
std::vector<paddle::experimental::Scalar>, attr_iter->second);
std::vector<phi::Scalar> scalar_list{vec.begin(), vec.end()};
phi_kernel_context->EmplaceBackAttr(std::move(scalar_list));
} break;
default:
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported cast op attribute `%s` to vector<Scalar> when "
......
......@@ -24,6 +24,7 @@ limitations under the License. */
#include "paddle/fluid/imperative/type_defs.h"
#include "paddle/phi/common/scalar.h"
#include "paddle/utils/blank.h"
#include "paddle/utils/small_vector.h"
#include "paddle/utils/variant.h"
......@@ -59,7 +60,9 @@ using Attribute = paddle::variant<paddle::blank,
std::vector<double>,
VarDesc*,
std::vector<VarDesc*>,
double>;
double,
paddle::experimental::Scalar,
std::vector<paddle::experimental::Scalar>>;
using AttributeMap = std::unordered_map<std::string, Attribute>;
#ifdef PADDLE_WITH_ASCEND_CL
......
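The `index() - 1` casts in `attribute.cc` and `op_desc.cc` rely on the variant's alternatives lining up with the `proto::AttrType` enum once `paddle::blank` at index 0 is skipped; a sketch of that mapping, assuming the ordering shown above:

```cpp
#include "paddle/fluid/framework/type_defs.h"

void IndexOffsetExample() {
  paddle::framework::Attribute attr = 3.14;  // holds the double alternative
  auto tag =
      static_cast<paddle::framework::proto::AttrType>(attr.index() - 1);
  // tag == proto::AttrType::FLOAT64 (15), since double sits right before the
  // two new Scalar alternatives at the end of the variant.
}
```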
......@@ -426,6 +426,10 @@ void BuildDygraphPhiKernelContext(const phi::KernelSignature& kernel_signature,
kernel_ctx->EmplaceBackAttr(
std::move(phi::Scalar(PADDLE_GET_CONST(bool, attr))));
break;
case framework::proto::AttrType::SCALAR:
kernel_ctx->EmplaceBackAttr(std::move(phi::Scalar(
PADDLE_GET_CONST(paddle::experimental::Scalar, attr))));
break;
default:
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported cast op attribute `%s` to Scalar when construct "
......@@ -533,6 +537,16 @@ void BuildDygraphPhiKernelContext(const phi::KernelSignature& kernel_signature,
}
kernel_ctx->EmplaceBackAttr(std::move(scalar_list));
} break;
case framework::proto::AttrType::SCALARS: {
const auto& vec = PADDLE_GET_CONST(
std::vector<paddle::experimental::Scalar>, attr);
std::vector<phi::Scalar> scalar_list;
scalar_list.reserve(vec.size());
for (const auto& val : vec) {
scalar_list.emplace_back(val);
}
kernel_ctx->EmplaceBackAttr(std::move(scalar_list));
} break;
default:
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported cast op attribute `%s` to vector<Scalar> when "
......
......@@ -55,17 +55,29 @@ void BuildPhiKernelContextAttr(const framework::OpDesc& op_desc,
auto& attr = *attr_ptr;
switch (AttrTypeID(attr)) {
case framework::proto::AttrType::FLOAT:
return kernel_context->EmplaceBackAttr(
kernel_context->EmplaceBackAttr(
phi::Scalar(PADDLE_GET_CONST(float, attr)));
break;
case framework::proto::AttrType::FLOAT64:
kernel_context->EmplaceBackAttr(
phi::Scalar(PADDLE_GET_CONST(double, attr)));
break;
case framework::proto::AttrType::INT:
return kernel_context->EmplaceBackAttr(
kernel_context->EmplaceBackAttr(
phi::Scalar(PADDLE_GET_CONST(int, attr)));
break;
case framework::proto::AttrType::LONG:
kernel_context->EmplaceBackAttr(
phi::Scalar(PADDLE_GET_CONST(int64_t, attr)));
break;
case framework::proto::AttrType::STRING:
return kernel_context->EmplaceBackAttr(
kernel_context->EmplaceBackAttr(
phi::Scalar(PADDLE_GET_CONST(std::string, attr)));
break;
case framework::proto::AttrType::SCALAR:
kernel_context->EmplaceBackAttr(phi::Scalar(
PADDLE_GET_CONST(paddle::experimental::Scalar, attr)));
break;
default:
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported cast op attribute `%s` to Scalar when "
......@@ -136,6 +148,16 @@ void BuildPhiKernelContextAttr(const framework::OpDesc& op_desc,
}
kernel_context->EmplaceBackAttr(std::move(scalar_list));
} break;
case framework::proto::AttrType::SCALARS: {
const auto& vec = PADDLE_GET_CONST(
std::vector<paddle::experimental::Scalar>, attr);
std::vector<phi::Scalar> scalar_list;
scalar_list.reserve(vec.size());
for (const auto& val : vec) {
scalar_list.emplace_back(val);
}
kernel_context->EmplaceBackAttr(std::move(scalar_list));
} break;
default:
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported cast op attribute `%s` to vector<Scalar> when "
......
......@@ -1308,10 +1308,18 @@ paddle::experimental::Scalar CastNumpy2Scalar(PyObject* obj,
} else if (type_name == "numpy.int32" || type_name == "numpy.intc") {
int value = CastPyArg2Int(obj, op_type, arg_pos);
return paddle::experimental::Scalar(value);
} else if (type_name == "numpy.complex64") {
phi::dtype::complex<float> value = CastPyArg2Complex(obj, op_type, arg_pos);
return paddle::experimental::Scalar(value);
} else if (type_name == "numpy.complex128") {
phi::dtype::complex<double> value =
CastPyArg2Complex128(obj, op_type, arg_pos);
return paddle::experimental::Scalar(value);
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument (position %d) must be "
"numpy.float16/float32/float64, numpy.int32/int64, but got %s",
"numpy.float32/float64, numpy.int32/int64, numpy.complex64/complex128, "
"but got %s",
op_type,
arg_pos + 1,
type_name)); // NOLINT
......@@ -1350,7 +1358,7 @@ paddle::experimental::Scalar CastPyArg2Scalar(PyObject* obj,
} else if (type_name.find("numpy") != std::string::npos) {
return CastNumpy2Scalar(obj, op_type, arg_pos);
} else if (PyComplex_Check(obj)) {
auto value = CastPyArg2Complex(obj, op_type, arg_pos);
auto value = CastPyArg2Complex128(obj, op_type, arg_pos);
return paddle::experimental::Scalar(value);
} else if (PyObject_CheckLongOrToLong(&obj)) {
int value = CastPyArg2Int(obj, op_type, arg_pos);
......@@ -1405,11 +1413,19 @@ std::vector<phi::Scalar> CastPyArg2ScalarArray(PyObject* obj,
phi::Scalar{static_cast<int64_t>(PyLong_AsLong(item))});
}
return value;
} else if (PyObject_CheckComplexOrToComplex(&item)) {
std::vector<phi::Scalar> value;
for (Py_ssize_t i = 0; i < len; i++) {
item = PyList_GetItem(obj, i);
Py_complex v = PyComplex_AsCComplex(item);
value.emplace_back(phi::Scalar{std::complex<double>(v.real, v.imag)});
}
return value;
}
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument (position %d) must be "
"a list of int, float, or bool, but got %s",
"a list of int, float, complex, or bool, but got %s",
op_type,
arg_pos + 1,
((PyTypeObject*)obj->ob_type)->tp_name)); // NOLINT
......
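A minimal sketch of the new complex path through the CPython C-API, mirroring `CastPyArg2Complex128` above (error handling and interpreter setup omitted):

```cpp
#include <Python.h>

#include "paddle/phi/common/complex.h"
#include "paddle/phi/common/scalar.h"

paddle::experimental::Scalar ComplexFromPyExample() {
  PyObject* obj = PyComplex_FromDoubles(42.1, 42.1);  // a Python complex
  Py_complex v = PyComplex_AsCComplex(obj);
  phi::dtype::complex<double> value(v.real, v.imag);
  Py_DECREF(obj);
  return paddle::experimental::Scalar(value);
}
```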
......@@ -31,6 +31,7 @@
#include "paddle/fluid/imperative/type_defs.h"
#include "paddle/fluid/operators/ops_extra_info.h"
#include "paddle/fluid/pybind/eager.h"
#include "paddle/fluid/pybind/eager_utils.h"
#include "paddle/fluid/pybind/imperative.h"
#include "paddle/phi/common/complex.h"
......@@ -107,6 +108,17 @@ bool PyObject_CheckFloatOrToFloat(PyObject** obj) {
return false;
}
bool PyObject_CheckComplexOrToComplex(PyObject** obj) {
if (PyComplex_Check(*obj) || PyLong_Check(*obj) || PyFloat_Check(*obj) ||
PyObject_IsInstance(*obj, (PyObject*)g_vartype_pytype) || // NOLINT
PyObject_IsInstance(*obj, (PyObject*)g_varbase_pytype) || // NOLINT
PyObject_IsInstance(*obj, (PyObject*)p_tensor_type)) { // NOLINT
return true;
}
// consider numpy cfloat & numpy cdouble?
return false;
}
bool PyObject_CheckString(PyObject* obj) { return PyUnicode_Check(obj); }
bool CastPyArg2Boolean(PyObject* obj,
......@@ -187,6 +199,14 @@ void CastPyArg2AttrLong(PyObject* obj,
attrs[key] = CastPyArg2Long(obj, op_type, arg_pos);
}
void CastPyArg2AttrScalar(PyObject* obj,
paddle::framework::AttributeMap& attrs, // NOLINT
const std::string& key,
const std::string& op_type,
ssize_t arg_pos) {
attrs[key] = CastPyArg2Scalar(obj, op_type, arg_pos);
}
float16 CastPyArg2Float16(PyObject* obj,
const std::string& op_type,
ssize_t arg_pos) {
......@@ -243,6 +263,25 @@ phi::dtype::complex<float> CastPyArg2Complex(PyObject* obj,
return phi::dtype::complex<float>(0, 0);
}
phi::dtype::complex<double> CastPyArg2Complex128(PyObject* obj,
const std::string& op_type,
ssize_t arg_pos) {
if (PyComplex_Check(obj)) {
double real = PyComplex_RealAsDouble(obj);
double imag = PyComplex_ImagAsDouble(obj);
return phi::dtype::complex<double>(real, imag);
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument (position %d) must be "
"complex, but got %s",
op_type,
arg_pos + 1,
((PyTypeObject*)obj->ob_type)->tp_name)); // NOLINT
}
return phi::dtype::complex<double>(0, 0);
}
void CastPyArg2AttrDouble(PyObject* obj,
paddle::framework::AttributeMap& attrs, // NOLINT
const std::string& key,
......@@ -644,6 +683,14 @@ void CastPyArg2AttrFloat64s(PyObject* obj,
attrs[key] = CastPyArg2Float64s(obj, op_type, arg_pos);
}
void CastPyArg2AttrScalars(PyObject* obj,
paddle::framework::AttributeMap& attrs, // NOLINT
const std::string& key,
const std::string& op_type,
ssize_t arg_pos) {
attrs[key] = CastPyArg2Scalars(obj, op_type, arg_pos);
}
std::vector<std::string> CastPyArg2Strings(PyObject* obj,
const std::string& op_type,
ssize_t arg_pos) {
......@@ -708,6 +755,64 @@ void CastPyArg2AttrStrings(PyObject* obj,
attrs[key] = CastPyArg2Strings(obj, op_type, arg_pos);
}
std::vector<paddle::experimental::Scalar> CastPyArg2Scalars(
PyObject* obj, const std::string& op_type, ssize_t arg_pos) {
if (obj == Py_None) {
PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument (position %d) must be "
"a list of int, float, or bool, but got %s",
op_type,
arg_pos + 1,
((PyTypeObject*)obj->ob_type)->tp_name)); // NOLINT
}
PyTypeObject* type = obj->ob_type;
auto type_name = std::string(type->tp_name);
VLOG(4) << "type_name: " << type_name;
if (PyList_Check(obj)) {
Py_ssize_t len = PyList_Size(obj);
PyObject* item = nullptr;
item = PyList_GetItem(obj, 0);
if (PyObject_CheckFloatOrToFloat(&item)) {
std::vector<paddle::experimental::Scalar> value;
for (Py_ssize_t i = 0; i < len; i++) {
item = PyList_GetItem(obj, i);
value.emplace_back(
paddle::experimental::Scalar{PyFloat_AsDouble(item)});
}
return value;
} else if (PyObject_CheckLongOrToLong(&item)) {
std::vector<paddle::experimental::Scalar> value;
for (Py_ssize_t i = 0; i < len; i++) {
item = PyList_GetItem(obj, i);
value.emplace_back(paddle::experimental::Scalar{
static_cast<int64_t>(PyLong_AsLong(item))});
}
return value;
} else if (PyObject_CheckComplexOrToComplex(&item)) {
std::vector<paddle::experimental::Scalar> value;
for (Py_ssize_t i = 0; i < len; i++) {
item = PyList_GetItem(obj, i);
Py_complex v = PyComplex_AsCComplex(item);
value.emplace_back(
paddle::experimental::Scalar{std::complex<double>(v.real, v.imag)});
}
return value;
}
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument (position %d) must be "
"a list of int, float, complex, or bool, but got %s",
op_type,
arg_pos + 1,
((PyTypeObject*)obj->ob_type)->tp_name)); // NOLINT
}
// Fake a ScalarArray
return std::vector<paddle::experimental::Scalar>(
{paddle::experimental::Scalar(1.0)});
}
void CastPyArg2AttrBlock(PyObject* obj,
paddle::framework::AttributeMap& attrs, // NOLINT
const std::string& key,
......@@ -810,6 +915,12 @@ void ConstructAttrMapFromPyArgs(
case paddle::framework::proto::AttrType::BLOCK:
CastPyArg2AttrBlock(obj, attrs, key, op_type, arg_pos);
break;
case paddle::framework::proto::AttrType::SCALAR:
CastPyArg2AttrScalar(obj, attrs, key, op_type, arg_pos);
break;
case paddle::framework::proto::AttrType::SCALARS:
CastPyArg2AttrScalars(obj, attrs, key, op_type, arg_pos);
break;
default:
break;
}
......
......@@ -46,6 +46,8 @@ bool PyObject_CheckLongOrToLong(PyObject** obj);
bool PyObject_CheckFloatOrToFloat(PyObject** obj);
bool PyObject_CheckComplexOrToComplex(PyObject** obj);
bool PyObject_CheckString(PyObject* obj);
bool CastPyArg2Boolean(PyObject* obj,
......@@ -67,6 +69,9 @@ double CastPyArg2Double(PyObject* obj,
phi::dtype::complex<float> CastPyArg2Complex(PyObject* obj,
const std::string& op_type,
ssize_t arg_pos);
phi::dtype::complex<double> CastPyArg2Complex128(PyObject* obj,
const std::string& op_type,
ssize_t arg_pos);
std::string CastPyArg2String(PyObject* obj,
const std::string& op_type,
ssize_t arg_pos);
......@@ -89,6 +94,9 @@ std::vector<std::string> CastPyArg2Strings(PyObject* obj,
const std::string& op_type,
ssize_t arg_pos);
std::vector<paddle::experimental::Scalar> CastPyArg2Scalars(
PyObject* obj, const std::string& op_type, ssize_t arg_pos);
void CastPyArg2AttrBoolean(PyObject* obj,
paddle::framework::AttributeMap& attrs, // NOLINT
const std::string& key,
......@@ -125,6 +133,12 @@ void CastPyArg2AttrString(PyObject* obj,
const std::string& op_type,
ssize_t arg_pos);
void CastPyArg2AttrScalar(PyObject* obj,
paddle::framework::AttributeMap& attrs, // NOLINT
const std::string& key,
const std::string& op_type,
ssize_t arg_pos);
void CastPyArg2AttrBooleans(PyObject* obj,
paddle::framework::AttributeMap& attrs, // NOLINT
const std::string& key,
......@@ -155,6 +169,12 @@ void CastPyArg2AttrFloat64s(PyObject* obj,
const std::string& op_type,
ssize_t arg_pos);
void CastPyArg2AttrScalars(PyObject* obj,
paddle::framework::AttributeMap& attrs, // NOLINT
const std::string& key,
const std::string& op_type,
ssize_t arg_pos);
void CastPyArg2AttrStrings(PyObject* obj,
paddle::framework::AttributeMap& attrs, // NOLINT
const std::string& key,
......
......@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/pybind/protobuf.h"
#include <complex>
#include <deque>
#include <iostream>
#include <string>
......@@ -27,6 +28,7 @@ limitations under the License. */
#include "paddle/fluid/framework/version.h"
#include "paddle/fluid/jit/property.h"
#include "paddle/fluid/pybind/pybind_variant_caster.h"
#include "paddle/phi/common/scalar.h"
namespace py = pybind11;
......@@ -287,7 +289,7 @@ void BindOpDesc(pybind11::module *m) {
.value("LONGS", pd::proto::AttrType::LONGS)
.value("FLOAT", pd::proto::AttrType::FLOAT)
.value("FLOATS", pd::proto::AttrType::FLOATS)
// .value("FLOAT64", pd::proto::AttrType::FLOAT64)
.value("FLOAT64", pd::proto::AttrType::FLOAT64)
.value("FLOAT64S", pd::proto::AttrType::FLOAT64S)
.value("STRING", pd::proto::AttrType::STRING)
.value("STRINGS", pd::proto::AttrType::STRINGS)
......@@ -296,7 +298,9 @@ void BindOpDesc(pybind11::module *m) {
.value("BLOCK", pd::proto::AttrType::BLOCK)
.value("BLOCKS", pd::proto::AttrType::BLOCKS)
.value("VAR", pd::proto::AttrType::VAR)
.value("VARS", pd::proto::AttrType::VARS);
.value("VARS", pd::proto::AttrType::VARS)
.value("SCALAR", pd::proto::AttrType::SCALAR)
.value("SCALARS", pd::proto::AttrType::SCALARS);
pybind11::class_<pd::OpDesc> op_desc(*m, "OpDesc", "");
op_desc
......@@ -379,6 +383,11 @@ void BindOpDesc(pybind11::module *m) {
.def("_set_strs_attr",
&pd::OpDesc::SetPlainAttr<std::vector<std::string>>)
.def("_set_scalar_attr",
&pd::OpDesc::SetPlainAttr<paddle::experimental::Scalar>)
.def("_set_scalars_attr",
&pd::OpDesc::SetPlainAttr<std::vector<paddle::experimental::Scalar>>)
.def(
"attr",
[](pd::OpDesc &self, const std::string &name, bool with_attr_var) {
......@@ -417,6 +426,44 @@ void BindOpDesc(pybind11::module *m) {
pybind11::return_value_policy::reference)
.def("inputs", [](pd::OpDesc &self) { return self.Inputs(); })
.def("outputs", &pd::OpDesc::Outputs);
pybind11::class_<paddle::experimental::Scalar> scalar(*m, "Scalar", "");
scalar.def(py::init<bool>())
.def(py::init<double>())
.def(py::init<int64_t>())
.def(py::init<std::complex<double>>())
.def(py::init<paddle::experimental::Scalar>())
.def("__str__", &paddle::experimental::Scalar::ToString)
.def("__repr__", &paddle::experimental::Scalar::ToString)
.def("__eq__", &paddle::experimental::Scalar::operator==<paddle::Tensor>)
.def("value",
[](const paddle::experimental::Scalar &self)
-> paddle::variant<bool, int64_t, double, std::complex<double>> {
auto dtype = self.dtype();
switch (dtype) {
case phi::DataType::FLOAT64:
case phi::DataType::FLOAT32:
return self.to<double>();
case phi::DataType::INT32:
case phi::DataType::INT64:
return self.to<int64_t>();
case phi::DataType::BOOL:
return self.to<bool>();
case phi::DataType::COMPLEX64:
case phi::DataType::COMPLEX128:
// go through paddle's complex first to avoid ambiguity
// when converting bfloat16 or float16 to std::complex<double>
return static_cast<std::complex<double>>(
self.to<phi::dtype::complex<double>>());
default:
PD_THROW("Invalid tensor data type `", dtype, "`.");
}
});
pybind11::implicitly_convertible<bool, paddle::experimental::Scalar>();
pybind11::implicitly_convertible<double, paddle::experimental::Scalar>();
pybind11::implicitly_convertible<int64_t, paddle::experimental::Scalar>();
pybind11::implicitly_convertible<std::complex<double>,
paddle::experimental::Scalar>();
}
// Serialize Class Property
......
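With these bindings, Python code can construct a `libpaddle.Scalar` directly from `bool`, `int`, `float`, or `complex` (the implicit conversions registered above), compare scalars with `==`, print them through `__str__`/`__repr__`, and recover the payload with `value()`, which returns a `bool`, `int`, `float`, or `complex` depending on the stored dtype.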
......@@ -44,5 +44,15 @@ ScalarBase<phi::DenseTensor>::ScalarBase(const phi::DenseTensor& tensor_in)
}
}
bool operator==(const Scalar& lhs, const Scalar& rhs) {
return lhs.operator==(rhs);
}
bool operator!=(const Scalar& lhs, const Scalar& rhs) {
return lhs.operator!=(rhs);
}
std::ostream& operator<<(std::ostream& os, const Scalar& s) {
return os << s.ToString();
}
} // namespace experimental
} // namespace paddle
......@@ -16,6 +16,8 @@ limitations under the License. */
#include <cstdint>
#include <limits>
#include <sstream>
#include <vector>
#include "paddle/phi/api/ext/exception.h"
#include "paddle/phi/common/data_type.h"
......@@ -85,10 +87,19 @@ class ScalarBase {
data_.c64 = val;
}
ScalarBase(std::complex<float> val) : dtype_(DataType::COMPLEX64) { // NOLINT
data_.c64 = val;
}
ScalarBase(complex128 val) : dtype_(DataType::COMPLEX128) { // NOLINT
data_.c128 = val;
}
ScalarBase(std::complex<double> val) // NOLINT
: dtype_(DataType::COMPLEX128) {
data_.c128 = val;
}
// The compatible method for fliud operators,
// and it will be removed in the future.
explicit ScalarBase(const std::string& str_value)
......@@ -100,7 +111,10 @@ class ScalarBase {
} else if (str_value == "nan") {
data_.f64 = std::numeric_limits<double>::quiet_NaN();
} else {
data_.f64 = std::stod(str_value);
// NOTE(chenfeiyu): parse via istringstream to support subnormal floating
// point numbers, which std::stod cannot handle
std::istringstream ss(str_value);
ss >> data_.f64;
}
}
......@@ -120,6 +134,7 @@ class ScalarBase {
template <typename RT>
inline RT to() const {
// TODO(chenfeiyu): warn on non-lossless cast.
switch (dtype_) {
case DataType::FLOAT32:
return static_cast<RT>(data_.f32);
......@@ -137,6 +152,10 @@ class ScalarBase {
return static_cast<RT>(data_.i16);
case DataType::INT8:
return static_cast<RT>(data_.i8);
case DataType::UINT64:
return static_cast<RT>(data_.ui64);
case DataType::UINT32:
return static_cast<RT>(data_.ui32);
case DataType::UINT16:
return static_cast<RT>(data_.ui16);
case DataType::UINT8:
......@@ -154,6 +173,107 @@ class ScalarBase {
DataType dtype() const { return dtype_; }
template <typename T2>
bool operator==(const ScalarBase<T2>& other) const {
DataType data_type = this->dtype();
if (data_type != other.dtype()) {
return false;
}
switch (data_type) {
case DataType::BOOL:
return this->data_.b == other.data_.b;
case DataType::INT8:
return this->data_.i8 == other.data_.i8;
case DataType::UINT8:
return this->data_.ui8 == other.data_.ui8;
case DataType::INT16:
return this->data_.i16 == other.data_.i16;
case DataType::UINT16:
return this->data_.ui16 == other.data_.ui16;
case DataType::INT32:
return this->data_.i32 == other.data_.i32;
case DataType::UINT32:
return this->data_.ui32 == other.data_.ui32;
case DataType::INT64:
return this->data_.i64 == other.data_.i64;
case DataType::UINT64:
return this->data_.ui64 == other.data_.ui64;
case DataType::FLOAT16:
return this->data_.f16 == other.data_.f16;
case DataType::BFLOAT16:
return this->data_.bf16 == other.data_.bf16;
case DataType::FLOAT32:
return this->data_.f32 == other.data_.f32;
case DataType::FLOAT64:
return this->data_.f64 == other.data_.f64;
case DataType::COMPLEX64:
return this->data_.c64 == other.data_.c64;
case DataType::COMPLEX128:
return this->data_.c128 == other.data_.c128;
default:
PD_THROW("Invalid tensor data type `", dtype_, "`.");
}
}
template <typename T2>
bool operator!=(const ScalarBase<T2>& other) const {
return !operator==(other);
}
std::string ToRawString() const {
std::stringstream ss;
switch (dtype_) {
case DataType::FLOAT32:
ss << data_.f32;
break;
case DataType::FLOAT64:
ss << data_.f64;
break;
case DataType::FLOAT16:
ss << data_.f16;
break;
case DataType::BFLOAT16:
ss << data_.bf16;
break;
case DataType::INT32:
ss << data_.i32;
break;
case DataType::INT64:
ss << data_.i64;
break;
case DataType::INT16:
ss << data_.i16;
break;
case DataType::INT8:
ss << data_.i8;
break;
case DataType::UINT16:
ss << data_.ui16;
break;
case DataType::UINT8:
ss << data_.ui8;
break;
case DataType::BOOL:
ss << data_.b;
break;
case DataType::COMPLEX64:
ss << data_.c64;
break;
case DataType::COMPLEX128:
ss << data_.c128;
break;
default:
break;
}
return ss.str();
}
std::string ToString() const {
std::stringstream ss;
ss << "Scalar(" << dtype_ << '(' << ToRawString() << "))";
return ss.str();
}
private:
template <typename T1, typename T2>
friend void CopyScalar(const ScalarBase<T1>& src, ScalarBase<T2>* dst);
......@@ -230,7 +350,31 @@ void CopyScalar(const ScalarBase<T1>& src, ScalarBase<T2>* dst) {
}
using Scalar = paddle::experimental::ScalarBase<Tensor>;
bool operator==(const Scalar& lhs, const Scalar& rhs);
std::ostream& operator<<(std::ostream& os, const Scalar& s);
template <typename T>
std::vector<T> ExtractPlainVector(
const std::vector<paddle::experimental::Scalar>& values) {
std::vector<T> results;
results.reserve(values.size());
for (const auto& item : values) {
results.push_back(item.to<T>());
}
return results;
}
template <typename T>
std::vector<paddle::experimental::Scalar> WrapAsScalars(
const std::vector<T>& values) {
std::vector<paddle::experimental::Scalar> results;
results.reserve(values.size());
for (const auto& item : values) {
results.push_back(paddle::experimental::Scalar(item));
}
return results;
}
} // namespace experimental
} // namespace paddle
......
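A short sketch (wrapper function illustrative) of the two new helpers round-tripping a plain vector:

```cpp
#include <cstdint>
#include <vector>

#include "paddle/phi/common/scalar.h"

void WrapAndExtractExample() {
  std::vector<int64_t> raw{1, 2, 3};
  std::vector<paddle::experimental::Scalar> scalars =
      paddle::experimental::WrapAsScalars(raw);
  std::vector<int64_t> back =
      paddle::experimental::ExtractPlainVector<int64_t>(scalars);
  // back == raw; each element goes through Scalar::to<int64_t>().
}
```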
......@@ -19,6 +19,7 @@ limitations under the License. */
#include <unordered_map>
#include <vector>
#include "paddle/phi/common/scalar.h"
#include "paddle/utils/blank.h"
DECLARE_int32(call_stack_level);
......@@ -46,7 +47,9 @@ using Attribute = paddle::variant<paddle::blank,
std::vector<double>,
VarDesc*,
std::vector<VarDesc*>,
double>;
double,
paddle::experimental::Scalar,
std::vector<paddle::experimental::Scalar>>;
using AttributeMap = std::unordered_map<std::string, Attribute>;
} // namespace framework
namespace imperative {
......
......@@ -13,6 +13,9 @@
# limitations under the License.
from collections import defaultdict
import numpy as np
from paddle.fluid import core
from paddle.fluid import framework
from paddle import _C_ops, _legacy_C_ops
......@@ -321,6 +324,14 @@ class Tracer(core.Tracer):
type, inputs, outputs, attrs, stop_gradient, inplace_map
)
else:
# this block converts attrs according to the op proto; since `trace`
# handles the AttributeMap directly, without further modification to
# the passed attribute map, we convert it before calling `trace`
if framework.OpProtoHolder.instance().has_op_proto(type):
proto = framework.OpProtoHolder.instance().get_op_proto(type)
attrs = framework.canonicalize_attrs(attrs, proto)
self.trace(
type,
inputs,
......
......@@ -11,6 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import textwrap
import collections
from collections import defaultdict
from collections.abc import Iterable
......@@ -30,7 +32,6 @@ import logging
from .proto import framework_pb2, data_feed_pb2
from . import core
from . import unique_name
import paddle.version as fluid_version
......@@ -72,6 +73,7 @@ GRAD_VAR_SUFFIX = core.kGradVarSuffix()
ZERO_VAR_SUFFIX = core.kZeroVarSuffix()
CONTROL_DEP_VAR_PREFIX = core.kControlDepVarName()
# use thread local to create thread-safe global variables.
class GlobalThreadLocal(threading.local):
def __init__(self):
......@@ -126,7 +128,6 @@ _cuda_graph_enable_standalone_executor_ = os.environ.get(
'FLAGS_CUDA_GRAPH_USE_STANDALONE_EXECUTOR', 0
)
# special_op_attrs, extra_op_attrs are prepared for printing warnings
# when turning on FLAGS_print_extra_attrs
special_op_attrs = {
......@@ -169,7 +170,6 @@ extra_op_attrs = {
"unique": ["is_sorted"],
}
# Some explanation of our execution system 2022.03
# For now we have 3 kinds of execution system, since we refactored dygraph mode to
# build a fast execution system for dynamic mode. But we can't just remove all legacy
......@@ -1424,6 +1424,101 @@ def _all_is_type(vals, expected_type):
return all(isinstance(v, expected_type) for v in vals)
def wrap_as_scalar(number):
"""Wrap a number(either python scalar or numpy scalar) as core.Scalar if
it is not a scalar.
Args:
number (Number): number
Returns:
Scalar: A Scalar that contains the value.
"""
if isinstance(number, core.Scalar):
return number
if isinstance(number, (bool, int, float, complex)):
return core.Scalar(number)
if isinstance(number, np.number):
# it is a numpy scalar
return core.Scalar(number.item())
else:
raise TypeError("Cannot wrap {} as core.Scalar".format(number))
def wrap_as_scalars(array):
"""This function is used to convert flat list, or numpy array(not
necesarily flat) to list of core.Scalar, which correspond to
std::vector<paddle::experimental::Scalar> in operator runtime.
Args:
array (List | np.ndarray): array of numbers
Returns:
List: list of core.Scalar, of which each element is a Scalar containing
the corresponding value.
"""
if isinstance(array, np.ndarray):
array = array.ravel().tolist()
return [wrap_as_scalar(item) for item in array]
def extract_plain_list(array):
"""extract value from a list of core.Scalar.
Args:
array (list): Scalars
Returns:
list: values extracted from the scalars.
"""
return [item.value() for item in array]
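A round-trip sketch for `wrap_as_scalars` and `extract_plain_list`, mirroring the unit tests added below (assuming this patch is built):

import numpy as np
from paddle.fluid import framework

arr = np.arange(6, dtype=np.float64).reshape(2, 3)
scalars = framework.wrap_as_scalars(arr)  # flattened into 6 core.Scalar objects
values = framework.extract_plain_list(scalars)
assert values == arr.ravel().tolist()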
def canonicalize_attrs(attrs, op_proto):
"""This function is used to canonicalize attributes(as a string->any dict)
according to the type specification in the OpProto. This is especially
important for operators that has any attributes of type Scalar or Scalars.
Though various frontends of phi kernels & paddle operators can wrap variables
of concrete types into Scalars(a tagged union of several numeric types) or
vector of Scalars. Paddle operator requires strict type matching.
Args:
attrs (Dict[str, Any]): attribute dict intended to pass to an operator.
op_proto (OpProto): Proto (signature) of the operator.
Returns:
Dict[str, Any]: canonicalized attributes.
"""
canonicalized_attrs = attrs.copy() # shallow copy is enough here
for attr in op_proto.attrs:
attr_name = attr.name
type_index = attr.type
if (attr_name not in attrs) or (attrs[attr_name] is None):
continue
attr_val = attrs[attr_name]
# VAR and VARS should be skipped
if isinstance(attr_val, Variable):
continue
if isinstance(attr_val, list) and _all_is_type(attr_val, Variable):
continue
# wrap
if type_index == core.AttrType.SCALAR:
canonicalized_attrs[attr_name] = core.Scalar(attr_val)
elif type_index == core.AttrType.SCALARS:
# it should be a list (or a numpy array)
if len(attr_val) > 0:
attr_val = np.array(attr_val).ravel().tolist()
attr_val = [core.Scalar(x) for x in attr_val]
canonicalized_attrs[attr_name] = attr_val
return canonicalized_attrs
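A hedged sketch of the intended call pattern for `canonicalize_attrs` (the op name "scale" and its attribute names are illustrative; the real OpProto decides which attributes are SCALAR-typed):

from paddle.fluid import framework

proto = framework.OpProtoHolder.instance().get_op_proto("scale")
attrs = {"scale": 2.0, "bias": 0.5}  # plain python numbers
attrs = framework.canonicalize_attrs(attrs, proto)
# attributes declared as SCALAR in the proto are now core.Scalar objects;
# everything else is returned unchanged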
class VariableMetaClass(type):
@classmethod
def __instancecheck__(cls, instance):
......@@ -3506,7 +3601,12 @@ class Operator:
return
type_index = self._attr_types[name]
if type_index == core.AttrType.BOOL:
# if the declared attribute type is SCALAR, wrap the value before setting it
if type_index == core.AttrType.SCALAR:
desc._set_scalar_attr(name, wrap_as_scalar(val))
elif type_index == core.AttrType.SCALARS:
desc._set_scalars_attr(name, wrap_as_scalars(val))
elif type_index == core.AttrType.BOOL:
desc._set_bool_attr(name, val)
elif type_index == core.AttrType.INT:
desc._set_int32_attr(name, val)
......@@ -3514,8 +3614,8 @@ class Operator:
desc._set_int64_attr(name, val)
elif type_index == core.AttrType.FLOAT:
desc._set_float32_attr(name, val)
# elif type_index == core.AttrType.FLOAT64:
# desc._set_float64_attr(name, val)
elif type_index == core.AttrType.FLOAT64:
desc._set_float64_attr(name, val)
elif type_index == core.AttrType.STRING:
desc._set_str_attr(name, val)
elif type_index == core.AttrType.BOOLS:
......
......@@ -17,6 +17,27 @@ import numpy as np
import paddle.fluid.core as core
import paddle.fluid.proto.framework_pb2 as framework_pb2
# NOTE: this is added to support creating a Scalar message
# from a python number
def make_scalar_proto(value):
s = framework_pb2.Scalar()
if isinstance(value, bool):
s.type = framework_pb2.Scalar.Type.BOOLEAN
s.b = value
elif isinstance(value, int):
s.type = framework_pb2.Scalar.Type.LONG
s.i = value
elif isinstance(value, float):
s.type = framework_pb2.Scalar.Type.FLOAT64
s.r = value
elif isinstance(value, complex):
s.type = framework_pb2.Scalar.Type.COMPLEX128
complex_value = framework_pb2.Complex()
complex_value.r = value.real
complex_value.i = value.imag
s.c.CopyFrom(complex_value)
return s
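A small sketch of the resulting proto messages (mirrors the TestScalarProto cases further down):

s = make_scalar_proto(42)  # s.type == Scalar.Type.LONG, s.i == 42
s = make_scalar_proto(42.1)  # s.type == Scalar.Type.FLOAT64, s.r == 42.1
s = make_scalar_proto(1 + 2j)  # s.type == Scalar.Type.COMPLEX128, s.c.r == 1.0, s.c.i == 2.0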
def get_all_op_protos():
"""
......@@ -127,6 +148,18 @@ class OpDescCreationMethod:
new_attr.longs.extend(user_defined_attr)
elif attr.type == framework_pb2.FLOAT64:
new_attr.float64 = user_defined_attr
elif attr.type == framework_pb2.FLOAT64S:
new_attr.float64s.extend(user_defined_attr)
# the code below manipulates protobuf directly
elif attr.type == framework_pb2.SCALAR:
scalar = make_scalar_proto(user_defined_attr)
new_attr.scalar.CopyFrom(scalar)
elif attr.type == framework_pb2.SCALARS:
scalars = [
make_scalar_proto(item) for item in user_defined_attr
]
new_attr.scalars.extend(scalars)
else:
raise NotImplementedError(
"A not supported attribute type: %s." % (str(attr.type))
......@@ -162,6 +195,20 @@ class OpDescCreationMethod:
new_attr.bools.extend(user_defined_attr)
elif attr_type == framework_pb2.LONGS:
new_attr.longs.extend(user_defined_attr)
elif attr_type == framework_pb2.FLOAT64:
new_attr.float64 = user_defined_attr
elif attr_type == framework_pb2.FLOAT64S:
new_attr.float64s.extend(user_defined_attr)
# the code below manipulates protobuf directly
elif attr_type == framework_pb2.SCALAR:
scalar = make_scalar_proto(user_defined_attr)
new_attr.scalar.CopyFrom(scalar)
elif attr_type == framework_pb2.SCALARS:
scalars = [
make_scalar_proto(item) for item in user_defined_attr
]
new_attr.scalars.extend(scalars)
else:
raise NotImplementedError(
"A not supported attribute type: %s." % (str(attr_type))
......
......@@ -33,6 +33,7 @@ from paddle.fluid.framework import (
OpProtoHolder,
Program,
_current_expected_place,
canonicalize_attrs,
)
from paddle.fluid.op import Operator
......@@ -969,7 +970,7 @@ class OpTest(unittest.TestCase):
self.op_type,
dygraph_tensor_inputs,
dygraph_tensor_outputs,
attrs_outputs,
canonicalize_attrs(attrs_outputs, op_proto),
)
if not kernel_sig or (
len(kernel_sig[0]) == 0
......
......@@ -19,12 +19,13 @@ import numpy as np
import paddle
from paddle import fluid
from paddle.fluid import core
from paddle.fluid import core, framework
from paddle.fluid.executor import global_scope
from paddle.fluid.framework import (
IrGraph,
IrNode,
Operator,
OpProtoHolder,
convert_np_dtype_to_dtype_,
)
from paddle.static.quantization import (
......@@ -182,7 +183,15 @@ class BlockConfig:
op_desc.set_type(op_config.type)
for name, values in op_config.inputs.items():
op_desc.set_input(name, values)
for name, values in op_config.attrs.items():
# canonicalize scalar attrs
if OpProtoHolder.instance().has_op_proto(op_config.type):
proto = OpProtoHolder.instance().get_op_proto(op_config.type)
canonicalized_attrs = framework.canonicalize_attrs(
op_config.attrs, proto
)
else:
canonicalized_attrs = op_config.attrs
for name, values in canonicalized_attrs.items():
op_desc._set_attr(name, values)
for name, values in op_config.outputs.items():
op_desc.set_output(name, values)
......@@ -323,9 +332,18 @@ def create_fake_model(program_config):
for op_config in program_config.ops:
op_desc = main_block_desc.append_op()
op_desc.set_type(op_config.type)
# canonicalize scalar attrs
if OpProtoHolder.instance().has_op_proto(op_config.type):
proto = OpProtoHolder.instance().get_op_proto(op_config.type)
canonicalized_attrs = framework.canonicalize_attrs(
op_config.attrs, proto
)
else:
canonicalized_attrs = op_config.attrs
for name, values in op_config.inputs.items():
op_desc.set_input(name, values)
for name, values in op_config.attrs.items():
for name, values in canonicalized_attrs.items():
if name == 'sub_block':
sub_block_desc = main_program_desc.append_block(main_block_desc)
values.fill_block_desc(sub_block_desc)
......
......@@ -37,6 +37,7 @@ from paddle.fluid.framework import (
_enable_legacy_dygraph,
_in_eager_without_dygraph_check,
_test_eager_guard,
canonicalize_attrs,
in_dygraph_mode,
)
from paddle.fluid.op import Operator
......@@ -964,7 +965,7 @@ class OpTest(unittest.TestCase):
self.op_type,
eager_tensor_inputs,
eager_tensor_outputs,
attrs_outputs,
canonicalize_attrs(attrs_outputs, op_proto),
)
if not kernel_sig:
return None
......
......@@ -21,7 +21,12 @@ import numpy as np
import paddle
from paddle.fluid import core
from paddle.fluid.framework import _dygraph_tracer, in_dygraph_mode
from paddle.fluid.framework import (
OpProtoHolder,
_dygraph_tracer,
canonicalize_attrs,
in_dygraph_mode,
)
from paddle.incubate.autograd import primapi
from paddle.jit.dy2static.utils import parse_arg_and_kwargs
......@@ -61,12 +66,17 @@ class OpTestUtils:
def _get_kernel_signature(
cls, op_type, eager_tensor_inputs, eager_tensor_outputs, attrs_outputs
):
try:
op_proto = OpProtoHolder.instance().get_op_proto(op_type)
canonicalized_attrs = canonicalize_attrs(attrs_outputs, op_proto)
except ValueError:
canonicalized_attrs = attrs_outputs
try:
kernel_sig = _dygraph_tracer()._get_kernel_signature(
op_type,
eager_tensor_inputs,
eager_tensor_outputs,
attrs_outputs,
canonicalized_attrs,
)
except RuntimeError as re:
"""we think the kernel_sig is missing."""
......
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
from paddle.fluid import framework, op
class TestWrapAsScalar(unittest.TestCase):
def test_for_int(self):
s = framework.wrap_as_scalar(np.iinfo(np.int64).max)
self.assertEqual(s, np.iinfo(np.int64).max)
def test_for_float(self):
maximum = float(np.finfo(np.float64).max)
s = framework.wrap_as_scalar(maximum)
self.assertEqual(s, maximum)
def test_for_bool(self):
s = framework.wrap_as_scalar(True)
self.assertEqual(s, True)
def test_for_complex(self):
c = 42.1 + 42.1j
s = framework.wrap_as_scalar(c)
self.assertEqual(s, c)
def test_for_numpy_scalar(self):
maximum = np.finfo(np.float64).max
s = framework.wrap_as_scalar(maximum)
self.assertEqual(s, maximum)
def test_for_scalar(self):
s1 = framework.wrap_as_scalar(42)
s2 = framework.wrap_as_scalar(s1)
self.assertEqual(s2, s1)
def test_for_exception(self):
with self.assertRaises(TypeError):
framework.wrap_as_scalar("abc")
class TestWrapAsScalars(unittest.TestCase):
def test_rewrap(self):
vec = [framework.wrap_as_scalar(item) for item in (1, 2, 3, 4)]
vec2 = framework.wrap_as_scalars(vec)
self.assertListEqual(vec, vec2)
def test_numpy_array(self):
arr = np.random.randn(2, 3).astype(np.float64)
scalars = framework.wrap_as_scalars(arr)
values = framework.extract_plain_list(scalars)
self.assertListEqual(arr.ravel().tolist(), values)
def test_numeric_list(self):
arr = [1 + 2j, 3 + 4j]
scalars = framework.wrap_as_scalars(arr)
values = framework.extract_plain_list(scalars)
self.assertListEqual(arr, values)
class TestScalarValue(unittest.TestCase):
def test_for_int(self):
s = framework.wrap_as_scalar(np.iinfo(np.int64).max)
self.assertEqual(s.value(), np.iinfo(np.int64).max)
def test_for_float(self):
maximum = float(np.finfo(np.float64).max)
s = framework.wrap_as_scalar(maximum)
self.assertEqual(s.value(), maximum)
def test_for_bool(self):
s = framework.wrap_as_scalar(True)
self.assertEqual(s.value(), True)
def test_for_complex(self):
c = 42.1 + 42.1j
s = framework.wrap_as_scalar(c)
self.assertEqual(s.value(), c)
def test_for_numpy_scalar(self):
maximum = np.finfo(np.float64).max
s = framework.wrap_as_scalar(maximum)
self.assertEqual(s.value(), float(maximum))
def test_for_scalar(self):
s1 = framework.wrap_as_scalar(42)
s2 = framework.wrap_as_scalar(s1)
self.assertEqual(s2.value(), s1.value())
class TestScalarProto(unittest.TestCase):
def test_make_scalar_proto_for_int(self):
s = op.make_scalar_proto(42)
self.assertEqual(s.i, 42)
def test_make_scalar_proto_for_float(self):
s = op.make_scalar_proto(42.1)
self.assertEqual(s.r, 42.1)
def test_make_scalar_proto_for_bool(self):
s = op.make_scalar_proto(True)
self.assertEqual(s.b, True)
def test_make_scalar_proto_for_complex(self):
s = op.make_scalar_proto(42.1 + 42.2j)
self.assertEqual(s.c.r, 42.1)
self.assertEqual(s.c.i, 42.2)
......@@ -18,6 +18,10 @@ cc_test(
phi_test_int_array
SRCS test_int_array.cc
DEPS int_array api_int_array phi phi_api)
cc_test(
phi_test_scalar_cpu
SRCS test_scalar.cc
DEPS scalar api_scalar)
if(WITH_GPU)
nv_test(
phi_test_scalar
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <complex>
#include <sstream>
#include <string>
#include "gtest/gtest.h"
#include "paddle/phi/common/scalar.h"
namespace phi {
namespace tests {
bool StartsWith(const std::string& s, const std::string& prefix) {
return s.rfind(prefix, 0) == 0;
}
TEST(Scalar, Formatting) {
paddle::experimental::Scalar s;
s = paddle::experimental::Scalar(static_cast<float>(42.1));
ASSERT_PRED2(StartsWith, s.ToString(), "Scalar(float32(");
s = paddle::experimental::Scalar(static_cast<double>(42.1));
ASSERT_PRED2(StartsWith, s.ToString(), "Scalar(float64(");
s = paddle::experimental::Scalar(static_cast<int>(42.1));
ASSERT_PRED2(StartsWith, s.ToString(), "Scalar(int32(");
s = paddle::experimental::Scalar(static_cast<int64_t>(42.1));
ASSERT_PRED2(StartsWith, s.ToString(), "Scalar(int64(");
s = paddle::experimental::Scalar(static_cast<bool>(true));
ASSERT_PRED2(StartsWith, s.ToString(), "Scalar(bool(");
s = paddle::experimental::Scalar(std::complex<float>(42.1, 42.1));
ASSERT_PRED2(StartsWith, s.ToString(), "Scalar(complex64(");
s = paddle::experimental::Scalar(std::complex<double>(42.1, 42.1));
ASSERT_PRED2(StartsWith, s.ToString(), "Scalar(complex128(");
s = paddle::experimental::Scalar(static_cast<phi::float16>(42.1));
ASSERT_PRED2(StartsWith, s.ToString(), "Scalar(float16(");
s = paddle::experimental::Scalar(static_cast<phi::bfloat16>(42.1));
ASSERT_PRED2(StartsWith, s.ToString(), "Scalar(bfloat16(");
s = paddle::experimental::Scalar(static_cast<int8_t>(42.1));
ASSERT_PRED2(StartsWith, s.ToString(), "Scalar(int8(");
s = paddle::experimental::Scalar(static_cast<int16_t>(42.1));
ASSERT_PRED2(StartsWith, s.ToString(), "Scalar(int16(");
s = paddle::experimental::Scalar(static_cast<uint8_t>(42.1));
ASSERT_PRED2(StartsWith, s.ToString(), "Scalar(uint8(");
s = paddle::experimental::Scalar(static_cast<uint16_t>(42.1));
ASSERT_PRED2(StartsWith, s.ToString(), "Scalar(uint16(");
s = paddle::experimental::Scalar(static_cast<uint32_t>(42.1));
ASSERT_PRED2(StartsWith, s.ToString(), "Scalar(uint32(");
s = paddle::experimental::Scalar(static_cast<uint64_t>(42.1));
ASSERT_PRED2(StartsWith, s.ToString(), "Scalar(uint64(");
std::stringstream ss;
s = paddle::experimental::Scalar(static_cast<uint64_t>(42.1));
ss << s;
ASSERT_PRED2(StartsWith, s.ToString(), "Scalar(uint64(");
}
TEST(Scalar, Equality) {
auto s_bool = paddle::experimental::Scalar(static_cast<bool>(true));
auto s_int8 = paddle::experimental::Scalar(static_cast<int8_t>(42.1));
auto s_int16 = paddle::experimental::Scalar(static_cast<int16_t>(42.1));
auto s_int32 = paddle::experimental::Scalar(static_cast<int32_t>(42.1));
auto s_int64 = paddle::experimental::Scalar(static_cast<int64_t>(42.1));
auto s_uint8 = paddle::experimental::Scalar(static_cast<uint8_t>(42.1));
auto s_uint16 = paddle::experimental::Scalar(static_cast<uint16_t>(42.1));
auto s_uint32 = paddle::experimental::Scalar(static_cast<uint32_t>(42.1));
auto s_uint64 = paddle::experimental::Scalar(static_cast<uint64_t>(42.1));
auto s_float16 =
paddle::experimental::Scalar(static_cast<phi::float16>(42.1));
auto s_bfloat16 =
paddle::experimental::Scalar(static_cast<phi::bfloat16>(42.1));
auto s_float = paddle::experimental::Scalar(static_cast<float>(42.1));
auto s_double = paddle::experimental::Scalar(static_cast<double>(42.1));
auto s_cfloat = paddle::experimental::Scalar(std::complex<float>(42.1, 42.1));
auto s_cdouble =
paddle::experimental::Scalar(std::complex<double>(42.1, 42.1));
ASSERT_EQ(s_bool, s_bool);
ASSERT_EQ(s_int8, s_int8);
ASSERT_EQ(s_int16, s_int16);
ASSERT_EQ(s_int32, s_int32);
ASSERT_EQ(s_int64, s_int64);
ASSERT_EQ(s_uint8, s_uint8);
ASSERT_EQ(s_uint16, s_uint16);
ASSERT_EQ(s_uint32, s_uint32);
ASSERT_EQ(s_uint64, s_uint64);
ASSERT_EQ(s_float16, s_float16);
ASSERT_EQ(s_bfloat16, s_bfloat16);
ASSERT_EQ(s_float, s_float);
ASSERT_EQ(s_double, s_double);
ASSERT_EQ(s_cfloat, s_cfloat);
ASSERT_EQ(s_cdouble, s_cdouble);
ASSERT_NE(s_float, s_double);
}
TEST(Scalar, WrapAsScalars) {
std::vector<int32_t> v{1, 2, 3};
auto out = paddle::experimental::WrapAsScalars(v);
ASSERT_EQ(out[0].dtype(), phi::DataType::INT32);
ASSERT_EQ(out[0].to<int32_t>(), 1);
ASSERT_EQ(out[1].to<int32_t>(), 2);
ASSERT_EQ(out[2].to<int32_t>(), 3);
}
} // namespace tests
} // namespace phi