Unverified commit 6af480ca, authored by liym27 and committed by GitHub

Support int64 for op assign_value. test=develop (#23179)

Parent: d6f72c4f
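
For context, a minimal sketch of what this commit enables at the Python level (assuming the fluid 1.x API used in the tests below; before this change, `assign` accepted only float32 and int32 ndarrays):

    import numpy
    import paddle.fluid as fluid
    import paddle.fluid.layers as layers

    # With this commit, assign() also accepts int64 ndarrays; the values are
    # carried in the new "int64_values" attribute of the assign_value op.
    val = numpy.array([[1, 2], [3, 4]], dtype='int64')
    x = layers.create_tensor(dtype='int64')
    layers.assign(input=val, output=x)

    exe = fluid.Executor(fluid.CPUPlace())
    [out] = exe.run(fluid.default_main_program(), feed={}, fetch_list=[x])
    print(out.dtype)  # int64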
@@ -52,10 +52,13 @@ class AssignValueOpMaker : public framework::OpProtoAndCheckerMaker {
              "Shape of values.");
     AddAttr<int>("dtype", "data type of values")
         .InEnum({framework::proto::VarType::INT32,
-                 framework::proto::VarType::FP32});
-    AddAttr<std::vector<float>>("fp32_values", "store the float values")
+                 framework::proto::VarType::FP32,
+                 framework::proto::VarType::INT64});
+    AddAttr<std::vector<float>>("fp32_values", "store the float32 values")
         .SetDefault({});
-    AddAttr<std::vector<int>>("int32_values", "store the int values")
+    AddAttr<std::vector<int>>("int32_values", "store the int32 values")
         .SetDefault({});
+    AddAttr<std::vector<int64_t>>("int64_values", "store the int64 values")
+        .SetDefault({});
     AddComment(R"DOC(
 AssignValue operator
@@ -75,4 +78,5 @@ REGISTER_OPERATOR(
     paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
     paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
 REGISTER_OP_CPU_KERNEL(assign_value, ops::AssignValueKernel<int>,
-                       ops::AssignValueKernel<float>);
+                       ops::AssignValueKernel<float>,
+                       ops::AssignValueKernel<int64_t>);
@@ -16,4 +16,5 @@ limitations under the License. */
 namespace ops = paddle::operators;
 REGISTER_OP_CUDA_KERNEL(assign_value, ops::AssignValueKernel<int>,
-                        ops::AssignValueKernel<float>);
+                        ops::AssignValueKernel<float>,
+                        ops::AssignValueKernel<int64_t>);
@@ -37,6 +37,9 @@ class AssignValueKernel : public framework::OpKernel<T> {
     case framework::proto::VarType::FP32:
       value_name = "fp32_values";
       break;
+    case framework::proto::VarType::INT64:
+      value_name = "int64_values";
+      break;
     default:
       PADDLE_THROW("Unsupported dtype for assign_value_op: %d", dtype);
       break;
......
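
The kernel never sees the ndarray itself: the literal values are stored in a dtype-specific attribute ("fp32_values", "int32_values", and now "int64_values"), and the "dtype" attribute tells the switch above which one to read. A small introspection sketch of that scheme (a hypothetical check, assuming the fluid 1.x Program/Operator API and that `assign` appends exactly one op):

    import numpy
    import paddle.fluid as fluid
    import paddle.fluid.layers as layers

    # Build a program containing one assign_value op and inspect its attributes.
    main_program = fluid.Program()
    with fluid.program_guard(main_program):
        x = layers.create_tensor(dtype='int64')
        layers.assign(input=numpy.array([1, 2, 3], dtype='int64'), output=x)

    op = main_program.global_block().ops[-1]  # the assign_value op
    print(op.type)                  # assign_value
    print(op.attr('int64_values'))  # [1, 2, 3]
    print(op.attr('shape'))         # [3]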
@@ -502,10 +502,13 @@ def assign(input, output=None):
     elif dtype == VarDesc.VarType.INT32:
         value_name = "int32_values"
         values = [int(v) for v in input.flat]
+    elif dtype == VarDesc.VarType.INT64:
+        value_name = "int64_values"
+        values = [int(v) for v in input.flat]
     else:
         raise TypeError(
             "When the type of 'input' in assign is numpy.ndarray, "
-            "the data type of 'input' must be float32 or int32, but "
+            "the data type of 'input' must be float32, int32 or int64, but "
             "received %s." % convert_dtype(dtype))
     if input.size > 1024 * 1024:
         raise ValueError("The size of input is too big. Please consider "
......
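
Each dtype branch flattens the ndarray into a plain Python list for the matching attribute, and the "shape" attribute lets the kernel restore the original layout. A tiny illustration of that flattening (pure numpy, no Paddle required):

    import numpy

    # What the int64 branch above computes for a 2x2 input:
    input = numpy.array([[1, 2], [3, 4]], dtype='int64')
    values = [int(v) for v in input.flat]  # row-major flattening
    print(values)             # [1, 2, 3, 4]
    print(list(input.shape))  # [2, 2], recorded in the "shape" attribute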
@@ -97,10 +97,8 @@ class TestAssignOpError(unittest.TestCase):
         self.assertRaises(TypeError, fluid.layers.assign, x4)
         x5 = np.array([[2.5, 2.5]], dtype='float64')
         self.assertRaises(TypeError, fluid.layers.assign, x5)
-        x6 = np.array([[2.5, 2.5]], dtype='int64')
+        x6 = np.array([[2.5, 2.5]], dtype='uint8')
         self.assertRaises(TypeError, fluid.layers.assign, x6)
-        x7 = np.array([[2.5, 2.5]], dtype='uint8')
-        self.assertRaises(TypeError, fluid.layers.assign, x7)
 
 
 if __name__ == '__main__':
......
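
The error test drops int64 from the rejected dtypes (it is now supported) and keeps uint8 as a representative unsupported type. A sketch of the observable behavior (again assuming the fluid 1.x API):

    import numpy
    import paddle.fluid as fluid

    # int64 ndarrays are now accepted ...
    fluid.layers.assign(numpy.array([[1, 2]], dtype='int64'))

    # ... while float64 and uint8 still hit the TypeError branch shown earlier.
    try:
        fluid.layers.assign(numpy.array([[2.5, 2.5]], dtype='float64'))
    except TypeError as e:
        print(e)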
@@ -14,42 +14,79 @@
 from __future__ import print_function
 
-import paddle.fluid as fluid
-import paddle.fluid.layers as layers
-import op_test
-import numpy
 import unittest
+
+import numpy
+import op_test
+
+import paddle.fluid as fluid
 import paddle.fluid.framework as framework
+import paddle.fluid.layers as layers
 
 
 class TestAssignValueOp(op_test.OpTest):
     def setUp(self):
         self.op_type = "assign_value"
-        x = numpy.random.random(size=(2, 5)).astype(numpy.float32)
         self.inputs = {}
-        self.outputs = {'Out': x}
-        self.attrs = {
-            'shape': x.shape,
-            'dtype': framework.convert_np_dtype_to_dtype_(x.dtype),
-            'fp32_values': [float(v) for v in x.flat]
-        }
+        self.attrs = {}
+        self.init_data()
+        self.attrs["shape"] = self.value.shape
+        self.attrs["dtype"] = framework.convert_np_dtype_to_dtype_(
+            self.value.dtype)
+        self.outputs = {"Out": self.value}
+
+    def init_data(self):
+        self.value = numpy.random.random(size=(2, 5)).astype(numpy.float32)
+        self.attrs["fp32_values"] = [float(v) for v in self.value.flat]
 
     def test_forward(self):
         self.check_output()
 
 
+class TestAssignValueOp2(TestAssignValueOp):
+    def init_data(self):
+        self.value = numpy.random.random(size=(2, 5)).astype(numpy.int32)
+        self.attrs["int32_values"] = [int(v) for v in self.value.flat]
+
+
+class TestAssignValueOp3(TestAssignValueOp):
+    def init_data(self):
+        self.value = numpy.random.random(size=(2, 5)).astype(numpy.int64)
+        self.attrs["int64_values"] = [int(v) for v in self.value.flat]
+
+
 class TestAssignApi(unittest.TestCase):
+    def setUp(self):
+        self.init_dtype()
+        self.value = (
+            -100 + 200 * numpy.random.random(size=(2, 5))).astype(self.dtype)
+        self.place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda(
+        ) else fluid.CPUPlace()
+
+    def init_dtype(self):
+        self.dtype = "float32"
+
     def test_assign(self):
-        val = (
-            -100 + 200 * numpy.random.random(size=(2, 5))).astype(numpy.int32)
-        x = layers.create_tensor(dtype="float32")
-        layers.assign(input=val, output=x)
-        exe = fluid.Executor(fluid.CPUPlace())
-        fetched_x = exe.run(fluid.default_main_program(),
-                            feed={},
-                            fetch_list=[x])[0]
+        main_program = fluid.Program()
+        with fluid.program_guard(main_program):
+            x = layers.create_tensor(dtype=self.dtype)
+            layers.assign(input=self.value, output=x)
+        exe = fluid.Executor(self.place)
+        [fetched_x] = exe.run(main_program, feed={}, fetch_list=[x])
         self.assertTrue(
-            numpy.array_equal(fetched_x, val),
-            "fetch_x=%s val=%s" % (fetched_x, val))
-        self.assertEqual(fetched_x.dtype, val.dtype)
+            numpy.array_equal(fetched_x, self.value),
+            "fetch_x=%s val=%s" % (fetched_x, self.value))
+        self.assertEqual(fetched_x.dtype, self.value.dtype)
+
+
+class TestAssignApi2(TestAssignApi):
+    def init_dtype(self):
+        self.dtype = "int32"
+
+
+class TestAssignApi3(TestAssignApi):
+    def init_dtype(self):
+        self.dtype = "int64"
 
 
 if __name__ == '__main__':
......