From 29c4fae112b7b904bb1ee67c43d1b7e91de0d27b Mon Sep 17 00:00:00 2001
From: wangchaochaohu
Date: Tue, 7 Apr 2020 12:20:33 +0800
Subject: [PATCH] Tensor value support (#23491)

* add support for using a Tensor as the value of the fill_constant Op
---
 paddle/fluid/operators/fill_constant_op.cc    | 15 ++--
 paddle/fluid/operators/fill_constant_op.h     | 16 ++++
 paddle/fluid/operators/optimizers/adam_op.cc  | 13 ---
 paddle/fluid/operators/optimizers/adam_op.cu  |  8 ++
 paddle/fluid/operators/optimizers/adam_op.h   |  8 ++
 python/paddle/fluid/layers/tensor.py          | 85 +++++++------------
 python/paddle/fluid/layers/utils.py           | 49 +++++++++++
 .../tests/unittests/test_fill_constant_op.py  | 57 ++++++++++++-
 8 files changed, 170 insertions(+), 81 deletions(-)

diff --git a/paddle/fluid/operators/fill_constant_op.cc b/paddle/fluid/operators/fill_constant_op.cc
index bfa9182b986..e9a3c6f90a8 100644
--- a/paddle/fluid/operators/fill_constant_op.cc
+++ b/paddle/fluid/operators/fill_constant_op.cc
@@ -48,16 +48,6 @@ class FillConstantOp : public framework::OperatorWithKernel {
         framework::proto::VarType::Type(ctx.Attr<int>("dtype")),
         ctx.GetPlace());
   }
-
-  framework::OpKernelType GetKernelTypeForVar(
-      const std::string& var_name, const Tensor& tensor,
-      const framework::OpKernelType& expected_kernel_type) const override {
-    if (var_name == "ShapeTensor" || var_name == "ShapeTensorList") {
-      return expected_kernel_type;
-    }
-    return framework::OpKernelType(expected_kernel_type.data_type_,
-                                   tensor.place(), tensor.layout());
-  }
 };
 
 class FillConstantOpVarTypeInference : public framework::VarTypeInference {
@@ -80,6 +70,11 @@ class FillConstantOpMaker : public framework::OpProtoAndCheckerMaker {
     AddAttr<std::vector<int64_t>>("shape",
                                   "(vector<int64_t>) The shape of the output")
         .SetDefault({});
+    AddInput("ValueTensor",
+             "(Tensor, optional) If provided, fill_constant Op will use this "
+             "as the value to set the output Tensor. It has a higher priority "
+             "than Attr(str_value), and the shape of this tensor MUST BE [1].")
+        .AsDispensable();
     AddInput("ShapeTensor",
              "(Tensor<int>), optional). The shape of the output."
"It has a higher priority than Attr(shape).") diff --git a/paddle/fluid/operators/fill_constant_op.h b/paddle/fluid/operators/fill_constant_op.h index 6213565ea74..2fde693cbb0 100644 --- a/paddle/fluid/operators/fill_constant_op.h +++ b/paddle/fluid/operators/fill_constant_op.h @@ -99,6 +99,22 @@ class FillConstantKernel : public framework::OpKernel { value = static_cast(tmp_value); } } + if (ctx.HasInput("ValueTensor")) { + auto *value_tensor = ctx.Input("ValueTensor"); + PADDLE_ENFORCE_EQ( + value_tensor->numel(), 1, + platform::errors::InvalidArgument( + "When use Tensor as value to set Tensor value in fill_cosntant, " + "value input(ValueTensor) size must be 1, but get %d", + value_tensor->numel())); + const T *tensor_data = value_tensor->data(); + framework::Tensor cpu_tensor; + if (platform::is_gpu_place(value_tensor->place())) { + TensorCopySync(*value_tensor, platform::CPUPlace(), &cpu_tensor); + tensor_data = cpu_tensor.data(); + } + value = tensor_data[0]; + } auto shape = GetShape(ctx); if (out_var->IsType()) { diff --git a/paddle/fluid/operators/optimizers/adam_op.cc b/paddle/fluid/operators/optimizers/adam_op.cc index 86bfd9232a4..8e4cce68acb 100644 --- a/paddle/fluid/operators/optimizers/adam_op.cc +++ b/paddle/fluid/operators/optimizers/adam_op.cc @@ -42,19 +42,6 @@ void AdamOp::InferShape(framework::InferShapeContext *ctx) const { platform::errors::NotFound( "Input(Beta2Pow) of AdamOp should not be null.")); - if (ctx->IsRuntime() && ctx->HasInput("Beta1Tensor")) { - auto beta1 = ctx->Inputs("Beta1Tensor"); - PADDLE_ENFORCE_EQ( - beta1.size(), 1, - platform::errors::InvalidArgument("Input(Beta1Tensor) size must be 1")); - } - if (ctx->IsRuntime() && ctx->HasInput("Beta2Tensor")) { - auto beta2 = ctx->Inputs("Beta2Tensor"); - PADDLE_ENFORCE_EQ( - beta2.size(), 1, - platform::errors::InvalidArgument("Input(Beta2Tensor) size must be 1")); - } - PADDLE_ENFORCE_EQ(ctx->HasOutput("ParamOut"), true, platform::errors::NotFound( "Output(ParamOut) of AdamOp should not be null.")); diff --git a/paddle/fluid/operators/optimizers/adam_op.cu b/paddle/fluid/operators/optimizers/adam_op.cu index fbab8cf063b..5373fe15f6d 100644 --- a/paddle/fluid/operators/optimizers/adam_op.cu +++ b/paddle/fluid/operators/optimizers/adam_op.cu @@ -151,11 +151,19 @@ class AdamOpCUDAKernel : public framework::OpKernel { T beta1 = static_cast(ctx.Attr("beta1")); if (ctx.HasInput("Beta1Tensor")) { auto* beta1_tensor = ctx.Input("Beta1Tensor"); + PADDLE_ENFORCE_EQ(beta1_tensor->numel(), 1, + platform::errors::InvalidArgument( + "Input(Beta1Tensor) size must be 1, but get %d", + beta1_tensor->numel())); beta1 = static_cast(GetAttrFromTensor(beta1_tensor)); } T beta2 = static_cast(ctx.Attr("beta2")); if (ctx.HasInput("Beta2Tensor")) { auto* beta2_tensor = ctx.Input("Beta2Tensor"); + PADDLE_ENFORCE_EQ(beta2_tensor->numel(), 1, + platform::errors::InvalidArgument( + "Input(Beta2Tensor) size must be 1, but get %d", + beta2_tensor->numel())); beta2 = static_cast(GetAttrFromTensor(beta2_tensor)); } VLOG(3) << "beta1_pow.numel() : " << beta1_pow->numel() diff --git a/paddle/fluid/operators/optimizers/adam_op.h b/paddle/fluid/operators/optimizers/adam_op.h index 11452480227..ff7075a7fc2 100644 --- a/paddle/fluid/operators/optimizers/adam_op.h +++ b/paddle/fluid/operators/optimizers/adam_op.h @@ -406,11 +406,19 @@ class AdamOpKernel : public framework::OpKernel { T beta1 = static_cast(ctx.Attr("beta1")); if (ctx.HasInput("Beta1Tensor")) { auto* beta1_tensor = ctx.Input("Beta1Tensor"); + 
+      PADDLE_ENFORCE_EQ(beta1_tensor->numel(), 1,
+                        platform::errors::InvalidArgument(
+                            "Input(Beta1Tensor) size must be 1, but got %d",
+                            beta1_tensor->numel()));
       beta1 = static_cast<T>(GetAttrFromTensor(beta1_tensor));
     }
     T beta2 = static_cast<T>(ctx.Attr<float>("beta2"));
     if (ctx.HasInput("Beta2Tensor")) {
       auto* beta2_tensor = ctx.Input<framework::Tensor>("Beta2Tensor");
+      PADDLE_ENFORCE_EQ(beta2_tensor->numel(), 1,
+                        platform::errors::InvalidArgument(
+                            "Input(Beta2Tensor) size must be 1, but got %d",
+                            beta2_tensor->numel()));
       beta2 = static_cast<T>(GetAttrFromTensor(beta2_tensor));
     }
     VLOG(3) << "beta1_pow.numel() : " << beta1_pow->numel()
diff --git a/python/paddle/fluid/layers/tensor.py b/python/paddle/fluid/layers/tensor.py
index 9efa313cf05..2386d1f27ca 100644
--- a/python/paddle/fluid/layers/tensor.py
+++ b/python/paddle/fluid/layers/tensor.py
@@ -550,8 +550,9 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None):
             If ``shape`` is an Variable, it should be an 1-D Tensor .
         dtype(np.dtype|core.VarDesc.VarType|str): Data type of the output tensor which
             can be float16, float32, float64, int32, int64.
-        value(float): The constant value used to initialize the Tensor to be created.
-        force_cpu(True): data should be on CPU if it's true, default value is False.
+        value(float16|float32|float64|int32|int64|Variable): The constant value used to initialize
+            the Tensor to be created. If value is a Variable, it should be a 1-D Tensor.
+        force_cpu(bool): data should be on CPU if it's true, default value is False.
         out(Variable, optional): Optional output which can be any created
             Variable that meets the requirements to store the result of operation.
             if out is None, a new Varibale will be create to store the result.
@@ -579,13 +580,21 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None):
           # attr shape is an Variable Tensor.
          shape = fluid.layers.fill_constant([1,2], "int32", 2) # shape=[2,2]
          data4 = fluid.layers.fill_constant(shape=shape, dtype='bool', value=True) # data4=[[True,True],[True,True]]
+
+          # attr value is a Variable Tensor.
+          val = fluid.layers.fill_constant([1], "float32", 2.0)  # val=[2.0]
+          data5 = fluid.layers.fill_constant(shape=[2,1], value=val, dtype='float32')  # data5=[[2.0],[2.0]]
     """
-    attrs = {'value': float(value), 'force_cpu': force_cpu}
-
-    if convert_dtype(dtype) in ['int64', 'int32']:
-        attrs['str_value'] = str(int(value))
+    inputs = {}
+    attrs = {'force_cpu': force_cpu}
+    if isinstance(value, Variable):
+        inputs['ValueTensor'] = value
     else:
-        attrs['str_value'] = str(float(value))
+        attrs['value'] = float(value)
+        if convert_dtype(dtype) in ['int64', 'int32']:
+            attrs['str_value'] = str(int(value))
+        else:
+            attrs['str_value'] = str(float(value))
 
     if in_dygraph_mode():
         if isinstance(shape, (list, tuple)):
@@ -596,6 +605,13 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None):
             shape = list(shape.numpy().astype(int))
         if out is None:
             out = _varbase_creator(dtype=dtype)
+
+        if isinstance(value, Variable):
+            if convert_dtype(dtype) in ['int64', 'int32']:
+                attrs['str_value'] = str(int(value.numpy()))
+            else:
+                attrs['str_value'] = str(float(value.numpy()))
+
         core.ops.fill_constant(out, 'value',
                                float(value), 'force_cpu', force_cpu, 'dtype',
                                out.dtype, 'str_value', attrs['str_value'],
@@ -608,55 +624,12 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None):
         ['bool', 'float16', 'float32', 'float64', 'int32', 'int64'],
         'fill_constant')
     check_type(shape, 'shape', (Variable, list, tuple), 'fill_constant')
-    inputs = {}
-    attrs = {'value': float(value), 'force_cpu': force_cpu}
-
-    if convert_dtype(dtype) in ['int64', 'int32']:
-        attrs['str_value'] = str(int(value))
-    else:
-        attrs['str_value'] = str(float(value))
-
-    def _get_attr_shape(list_shape):
-        attr_shape = []
-        for idx, dim in enumerate(list_shape):
-            if isinstance(dim, Variable):
-                attr_shape.append(-1)
-            else:
-                attr_shape.append(dim)
-        return attr_shape
-
-    def _get_shape_tensor(list_shape):
-        new_shape_tensor = []
-        for idx, dim in enumerate(list_shape):
-            if isinstance(dim, Variable):
-                dim.stop_gradient = True
-                check_dtype(
-                    dim.dtype, 'shape[' + str(idx) + ']', ['int32', 'int64'],
-                    'fill_constant',
-                    '(When type of shape in fill_constant is list or tuple.)')
-                if convert_dtype(dim.dtype) == 'int64':
-                    dim = cast(x=dim, dtype='int32')
-                new_shape_tensor.append(dim)
-            else:
-                temp_out = helper.create_variable_for_type_inference('int32')
-                fill_constant([1], 'int32', dim, force_cpu=True, out=temp_out)
-                new_shape_tensor.append(temp_out)
-        return new_shape_tensor
-
-    if isinstance(shape, Variable):
-        shape.stop_gradient = True
-        check_dtype(shape.dtype, 'shape', ['int32', 'int64'], 'fill_constant',
-                    '(When type of shape in fill_constant is Variable.)')
-        if (convert_dtype(shape.dtype) == 'int64'):
-            shape = cast(shape, 'int32')
-        inputs["ShapeTensor"] = shape
-    elif isinstance(shape, (list, tuple)):
-        assert len(shape) > 0, (
-            "The size of 'shape' in fill_constant can't be zero, "
-            "but received %s." % len(shape))
-        attrs["shape"] = _get_attr_shape(shape)
-        if utils._contain_var(shape):
-            inputs['ShapeTensorList'] = _get_shape_tensor(shape)
+    inputs = utils._get_shape_tensor_inputs(
+        inputs=inputs,
+        helper=helper,
+        attrs=attrs,
+        shape=shape,
+        op_type='fill_constant')
 
     if out is None:
         out = helper.create_variable_for_type_inference(dtype=dtype)
diff --git a/python/paddle/fluid/layers/utils.py b/python/paddle/fluid/layers/utils.py
index 57d2547f694..0bfd95a6c00 100644
--- a/python/paddle/fluid/layers/utils.py
+++ b/python/paddle/fluid/layers/utils.py
@@ -18,6 +18,8 @@ import copy
 import six
 import numpy as np
 from ..framework import Variable
+from ..data_feeder import convert_dtype, check_variable_and_dtype, check_type, check_dtype
+from ..layer_helper import LayerHelper
 
 
 def convert_to_list(value, n, name, dtype=np.int):
@@ -274,3 +276,50 @@ def _contain_var(list_or_tuple):
         if isinstance(item, Variable):
             return True
     return False
+
+
+def _get_shape_tensor_inputs(inputs, helper, attrs, shape, op_type):
+    from .tensor import fill_constant, cast
+
+    def _get_attr_shape(list_shape):
+        attr_shape = []
+        for idx, dim in enumerate(list_shape):
+            if isinstance(dim, Variable):
+                attr_shape.append(-1)
+            else:
+                attr_shape.append(dim)
+        return attr_shape
+
+    def _get_shape_tensor(list_shape):
+        new_shape_tensor = []
+        for idx, dim in enumerate(list_shape):
+            if isinstance(dim, Variable):
+                dim.stop_gradient = True
+                check_dtype(
+                    dim.dtype, 'shape[' + str(idx) + ']', ['int32', 'int64'],
+                    op_type,
+                    '(When type of shape in ' + op_type + ' is list or tuple.)')
+                if convert_dtype(dim.dtype) == 'int64':
+                    dim = cast(x=dim, dtype='int32')
+                new_shape_tensor.append(dim)
+            else:
+                temp_out = fill_constant([1], 'int32', dim, force_cpu=True)
+                new_shape_tensor.append(temp_out)
+        return new_shape_tensor
+
+    if isinstance(shape, Variable):
+        shape.stop_gradient = True
+        check_dtype(shape.dtype, 'shape', ['int32', 'int64'], op_type,
+                    '(When type of shape in ' + op_type + ' is Variable.)')
+        if (convert_dtype(shape.dtype) == 'int64'):
+            shape = cast(shape, 'int32')
+        inputs["ShapeTensor"] = shape
+    elif isinstance(shape, (list, tuple)):
+        assert len(shape) > 0, (
+            "The size of 'shape' in " + op_type + " can't be zero, "
+            "but received %s." % len(shape))
+        attrs["shape"] = _get_attr_shape(shape)
+        if _contain_var(shape):
+            inputs['ShapeTensorList'] = _get_shape_tensor(shape)
+
+    return inputs
diff --git a/python/paddle/fluid/tests/unittests/test_fill_constant_op.py b/python/paddle/fluid/tests/unittests/test_fill_constant_op.py
index e6a6df6bdac..f87c4d42071 100644
--- a/python/paddle/fluid/tests/unittests/test_fill_constant_op.py
+++ b/python/paddle/fluid/tests/unittests/test_fill_constant_op.py
@@ -212,6 +212,54 @@ class TestFillConstantOp1_ShapeTensor(OpTest):
         self.check_output()
 
 
+# Situation 4: value is a tensor
+class TestFillConstantOp1_ValueTensor(OpTest):
+    def setUp(self):
+        '''Test fill_constant op with specified value
+        '''
+        self.op_type = "fill_constant"
+        self.init_data()
+
+        self.inputs = {
+            "ShapeTensor": np.array(self.shape).astype("int32"),
+            'ValueTensor': np.array([self.value]).astype("float32")
+        }
+        self.attrs = {'value': self.value + 1.0}
+        self.outputs = {'Out': np.full(self.shape, self.value)}
+
+    def init_data(self):
+        self.shape = [123, 92]
+        self.value = 3.8
+        self.dtype = np.float32
+
+    def test_check_output(self):
+        self.check_output()
+
+
+# Situation 5: value is a tensor
+class TestFillConstantOp2_ValueTensor(OpTest):
+    def setUp(self):
+        '''Test fill_constant op with specified value
+        '''
+        self.op_type = "fill_constant"
+        self.init_data()
+
+        self.inputs = {
+            "ShapeTensor": np.array(self.shape).astype("int32"),
+            'ValueTensor': np.array([self.value]).astype("int32")
+        }
+        self.attrs = {'value': self.value, 'dtype': 2}
+        self.outputs = {'Out': np.full(self.shape, self.value)}
+
+    def init_data(self):
+        self.shape = [123, 92]
+        self.value = 3
+        self.dtype = np.int32
+
+    def test_check_output(self):
+        self.check_output()
+
+
 # Test python API
 class TestFillConstantAPI(unittest.TestCase):
     def test_api(self):
@@ -242,14 +290,18 @@ class TestFillConstantAPI(unittest.TestCase):
         out_6 = fluid.layers.fill_constant(
             shape=shape_tensor_int64, dtype=np.float32, value=1.1)
 
+        val = fluid.layers.fill_constant(shape=[1], dtype=np.float32, value=1.1)
+        out_7 = fluid.layers.fill_constant(
+            shape=shape_tensor_int64, dtype=np.float32, value=val)
+
         exe = fluid.Executor(place=fluid.CPUPlace())
-        res_1, res_2, res_3, res_4, res_5, res_6 = exe.run(
+        res_1, res_2, res_3, res_4, res_5, res_6, res_7 = exe.run(
             fluid.default_main_program(),
             feed={
                 "shape_tensor_int32": np.array([1, 2]).astype("int32"),
                 "shape_tensor_int64": np.array([1, 2]).astype("int64"),
             },
-            fetch_list=[out_1, out_2, out_3, out_4, out_5, out_6])
+            fetch_list=[out_1, out_2, out_3, out_4, out_5, out_6, out_7])
 
         assert np.array_equal(res_1, np.full([1, 2], 1.1, dtype="float32"))
         assert np.array_equal(res_2, np.full([1, 2], 1.1, dtype="float32"))
@@ -257,6 +309,7 @@ class TestFillConstantAPI(unittest.TestCase):
         assert np.array_equal(res_4, np.full([1, 2], 1.1, dtype="float32"))
         assert np.array_equal(res_5, np.full([1, 2], 1.1, dtype="float32"))
         assert np.array_equal(res_6, np.full([1, 2], 1.1, dtype="float32"))
+        assert np.array_equal(res_7, np.full([1, 2], 1.1, dtype="float32"))
 
 
 class TestFillConstantOpError(unittest.TestCase):
-- 
GitLab
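
Usage sketch (not part of the patch itself): a minimal example of the feature this change adds, written against the fluid API of that release. It mirrors the docstring example and TestFillConstantAPI above; the program and executor names (main_prog, startup_prog, exe) are illustrative only.

    import paddle.fluid as fluid

    main_prog = fluid.Program()
    startup_prog = fluid.Program()
    with fluid.program_guard(main_prog, startup_prog):
        # value is passed as a 1-D Variable of shape [1]; fill_constant wires it
        # into the op as the new ValueTensor input, which takes priority over
        # the 'value'/'str_value' attributes.
        val = fluid.layers.fill_constant(shape=[1], dtype='float32', value=2.0)
        data = fluid.layers.fill_constant(shape=[2, 3], dtype='float32', value=val)

    exe = fluid.Executor(fluid.CPUPlace())
    out = exe.run(main_prog, fetch_list=[data])[0]
    print(out)  # expected: a [2, 3] array filled with 2.0

In dygraph mode the same call is expected to work as well: the patch converts value.numpy() into the 'str_value' attribute before invoking core.ops.fill_constant.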