未验证 提交 f3d7db98 编写于 作者: W wawltor 提交者: GitHub

Add the support of bool list for assign_value op (#23774)

* Add the support of bool list for assign value, test=develop
* Fix the assign op test case for bool dtype, test=develop
上级 03e737ac
...@@ -83,6 +83,29 @@ void TensorFromDLPack(const ::DLTensor& dl_tensor, framework::Tensor* dst); ...@@ -83,6 +83,29 @@ void TensorFromDLPack(const ::DLTensor& dl_tensor, framework::Tensor* dst);
// The implementation of template functions. // The implementation of template functions.
// //
template <typename T>
void TensorFromArray(const T* src, const size_t& array_size,
const platform::DeviceContext& ctx, Tensor* dst) {
auto dst_place = ctx.GetPlace();
auto src_ptr = static_cast<const void*>(src);
platform::CPUPlace src_place;
dst->Resize({static_cast<int64_t>(array_size)});
auto dst_ptr = static_cast<void*>(dst->mutable_data<T>(dst_place));
auto size = array_size * sizeof(T);
if (platform::is_cpu_place(dst_place)) {
memory::Copy(boost::get<platform::CPUPlace>(dst_place), dst_ptr, src_place,
src_ptr, size);
}
#ifdef PADDLE_WITH_CUDA
else if (platform::is_gpu_place(dst_place)) { // NOLINT
memory::Copy(
boost::get<platform::CUDAPlace>(dst_place), dst_ptr, src_place, src_ptr,
size,
reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream());
}
#endif
}
template <typename T> template <typename T>
void TensorFromVector(const std::vector<T>& src, void TensorFromVector(const std::vector<T>& src,
const platform::DeviceContext& ctx, Tensor* dst) { const platform::DeviceContext& ctx, Tensor* dst) {
......
...@@ -52,9 +52,12 @@ class AssignValueOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -52,9 +52,12 @@ class AssignValueOpMaker : public framework::OpProtoAndCheckerMaker {
"(vector<int>) " "(vector<int>) "
"Shape of values."); "Shape of values.");
AddAttr<int>("dtype", "data type of values") AddAttr<int>("dtype", "data type of values")
.InEnum({framework::proto::VarType::INT32, .InEnum({framework::proto::VarType::BOOL,
framework::proto::VarType::INT32,
framework::proto::VarType::FP32, framework::proto::VarType::FP32,
framework::proto::VarType::INT64}); framework::proto::VarType::INT64});
AddAttr<std::vector<int>>("bool_values", "store the bool values")
.SetDefault({});
AddAttr<std::vector<float>>("fp32_values", "store the float32 values") AddAttr<std::vector<float>>("fp32_values", "store the float32 values")
.SetDefault({}); .SetDefault({});
AddAttr<std::vector<int>>("int32_values", "store the int32 values") AddAttr<std::vector<int>>("int32_values", "store the int32 values")
...@@ -78,6 +81,7 @@ REGISTER_OPERATOR( ...@@ -78,6 +81,7 @@ REGISTER_OPERATOR(
assign_value, ops::AssignValueOp, ops::AssignValueOpMaker, assign_value, ops::AssignValueOp, ops::AssignValueOpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>, paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>); paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OP_CPU_KERNEL(assign_value, ops::AssignValueKernel<int>, REGISTER_OP_CPU_KERNEL(assign_value, ops::AssignValueKernel<bool>,
ops::AssignValueKernel<int>,
ops::AssignValueKernel<float>, ops::AssignValueKernel<float>,
ops::AssignValueKernel<int64_t>); ops::AssignValueKernel<int64_t>);
...@@ -15,6 +15,7 @@ limitations under the License. */ ...@@ -15,6 +15,7 @@ limitations under the License. */
#include "paddle/fluid/operators/assign_value_op.h" #include "paddle/fluid/operators/assign_value_op.h"
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(assign_value, ops::AssignValueKernel<int>, REGISTER_OP_CUDA_KERNEL(assign_value, ops::AssignValueKernel<bool>,
ops::AssignValueKernel<int>,
ops::AssignValueKernel<float>, ops::AssignValueKernel<float>,
ops::AssignValueKernel<int64_t>); ops::AssignValueKernel<int64_t>);
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
#pragma once #pragma once
#include <string>
#include <vector> #include <vector>
#include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
...@@ -22,6 +23,37 @@ ...@@ -22,6 +23,37 @@
namespace paddle { namespace paddle {
namespace operators { namespace operators {
using Tensor = framework::Tensor;
template <typename T>
typename std::enable_if<std::is_same<T, bool>::value>::type CopyVecotorToTensor(
const char* value_name, framework::Tensor* out,
const framework::ExecutionContext& ctx) {
// If attribute value dtype is vector<bool>, it will be converted to
// vector<int>.
// at the same time, we can not use vector<bool> to hold the value, because
// the c++ use bit value to replace byte value.
auto values = ctx.Attr<std::vector<int>>(value_name);
framework::TensorFromVector(values, ctx.device_context(), out);
// use the array to replace to vector
bool* array_ptr = new T[values.size()];
for (unsigned int i = 0; i < values.size(); i++) {
array_ptr[i] = static_cast<T>(values[i]);
}
framework::TensorFromArray(array_ptr, values.size(), ctx.device_context(),
out);
delete[] array_ptr;
}
template <typename T>
typename std::enable_if<!std::is_same<T, bool>::value>::type
CopyVecotorToTensor(const char* value_name, framework::Tensor* out,
const framework::ExecutionContext& ctx) {
auto values = ctx.Attr<std::vector<T>>(value_name);
framework::TensorFromVector(values, ctx.device_context(), out);
}
template <typename T> template <typename T>
class AssignValueKernel : public framework::OpKernel<T> { class AssignValueKernel : public framework::OpKernel<T> {
public: public:
...@@ -31,6 +63,9 @@ class AssignValueKernel : public framework::OpKernel<T> { ...@@ -31,6 +63,9 @@ class AssignValueKernel : public framework::OpKernel<T> {
int dtype = ctx.Attr<int>("dtype"); int dtype = ctx.Attr<int>("dtype");
const char* value_name = nullptr; const char* value_name = nullptr;
switch (dtype) { switch (dtype) {
case framework::proto::VarType::BOOL:
value_name = "bool_values";
break;
case framework::proto::VarType::INT32: case framework::proto::VarType::INT32:
value_name = "int32_values"; value_name = "int32_values";
break; break;
...@@ -44,8 +79,7 @@ class AssignValueKernel : public framework::OpKernel<T> { ...@@ -44,8 +79,7 @@ class AssignValueKernel : public framework::OpKernel<T> {
PADDLE_THROW("Unsupported dtype for assign_value_op: %d", dtype); PADDLE_THROW("Unsupported dtype for assign_value_op: %d", dtype);
break; break;
} }
auto values = ctx.Attr<std::vector<T>>(value_name); CopyVecotorToTensor<T>(value_name, out, ctx);
framework::TensorFromVector(values, ctx.device_context(), out);
out->Resize(framework::make_ddim(shape)); out->Resize(framework::make_ddim(shape));
} }
}; };
......
...@@ -526,7 +526,10 @@ def assign(input, output=None): ...@@ -526,7 +526,10 @@ def assign(input, output=None):
type='assign', inputs={'X': [input]}, outputs={'Out': [output]}) type='assign', inputs={'X': [input]}, outputs={'Out': [output]})
elif isinstance(input, numpy.ndarray): elif isinstance(input, numpy.ndarray):
dtype = convert_np_dtype_to_dtype_(input.dtype) dtype = convert_np_dtype_to_dtype_(input.dtype)
if dtype == VarDesc.VarType.FP32: if dtype == VarDesc.VarType.BOOL:
value_name = "bool_values"
values = [bool(v) for v in input.flat]
elif dtype == VarDesc.VarType.FP32:
value_name = "fp32_values" value_name = "fp32_values"
values = [float(v) for v in input.flat] values = [float(v) for v in input.flat]
elif dtype == VarDesc.VarType.INT32: elif dtype == VarDesc.VarType.INT32:
...@@ -538,7 +541,7 @@ def assign(input, output=None): ...@@ -538,7 +541,7 @@ def assign(input, output=None):
else: else:
raise TypeError( raise TypeError(
"When the type of 'input' in assign is numpy.ndarray, " "When the type of 'input' in assign is numpy.ndarray, "
"the data type of 'input' must be float32, int32 or int64, but " "the data type of 'input' must be bool, float32, int32 or int64, but "
"received %s." % convert_dtype(dtype)) "received %s." % convert_dtype(dtype))
if input.size > 1024 * 1024: if input.size > 1024 * 1024:
raise ValueError("The size of input is too big. Please consider " raise ValueError("The size of input is too big. Please consider "
......
...@@ -93,12 +93,10 @@ class TestAssignOpError(unittest.TestCase): ...@@ -93,12 +93,10 @@ class TestAssignOpError(unittest.TestCase):
x3 = fluid.layers.data(name='x3', shape=[4], dtype="uint8") x3 = fluid.layers.data(name='x3', shape=[4], dtype="uint8")
self.assertRaises(TypeError, fluid.layers.assign, x3) self.assertRaises(TypeError, fluid.layers.assign, x3)
# When the type of input is numpy.ndarray, the dtype of input must be float32, int32. # When the type of input is numpy.ndarray, the dtype of input must be float32, int32.
x4 = np.array([[2.5, 2.5]], dtype='bool') x4 = np.array([[2.5, 2.5]], dtype='float64')
self.assertRaises(TypeError, fluid.layers.assign, x4) self.assertRaises(TypeError, fluid.layers.assign, x4)
x5 = np.array([[2.5, 2.5]], dtype='float64') x5 = np.array([[2.5, 2.5]], dtype='uint8')
self.assertRaises(TypeError, fluid.layers.assign, x5) self.assertRaises(TypeError, fluid.layers.assign, x5)
x6 = np.array([[2.5, 2.5]], dtype='uint8')
self.assertRaises(TypeError, fluid.layers.assign, x6)
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -54,6 +54,13 @@ class TestAssignValueOp3(TestAssignValueOp): ...@@ -54,6 +54,13 @@ class TestAssignValueOp3(TestAssignValueOp):
self.attrs["int64_values"] = [int(v) for v in self.value.flat] self.attrs["int64_values"] = [int(v) for v in self.value.flat]
class TestAssignValueOp4(TestAssignValueOp):
def init_data(self):
self.value = numpy.random.choice(
a=[False, True], size=(2, 5)).astype(numpy.bool)
self.attrs["bool_values"] = [bool(v) for v in self.value.flat]
class TestAssignApi(unittest.TestCase): class TestAssignApi(unittest.TestCase):
def setUp(self): def setUp(self):
self.init_dtype() self.init_dtype()
...@@ -89,5 +96,17 @@ class TestAssignApi3(TestAssignApi): ...@@ -89,5 +96,17 @@ class TestAssignApi3(TestAssignApi):
self.dtype = "int64" self.dtype = "int64"
class TestAssignApi4(TestAssignApi):
def setUp(self):
self.init_dtype()
self.value = numpy.random.choice(
a=[False, True], size=(2, 5)).astype(numpy.bool)
self.place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda(
) else fluid.CPUPlace()
def init_dtype(self):
self.dtype = "bool"
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册