提交 fc6ec3b9 编写于 作者: L liym27 提交者: Tao Luo

fill_constant support Tensor; (#20521)

2. fix bug in backward.py: using fill_constant instead of fill_constant_batch_size_like
3. fix bug in ExpandGradOp.

test=develop
上级 0130cc96
......@@ -31,7 +31,6 @@ class ExpandOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true, "Input(X) should not be null.");
PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
"Output(Out) should not be null.");
auto x_dims = ctx->GetInputDim("X");
auto expand_times = ctx->Attrs().Get<std::vector<int>>("expand_times");
......@@ -162,10 +161,13 @@ class ExpandGradOp : public framework::OperatorWithKernel {
if (expand_times[i] == -1) {
continue;
} else {
PADDLE_ENFORCE_EQ(x_dims[i] * expand_times[i], out_dims[i],
"Each dimension size of Input(Out@GRAD) should be "
"equal to multiplication of crroresponding dimension "
"size of Input(X) and Attr(expand_times) value.");
if (ctx->IsRuntime()) {
PADDLE_ENFORCE_EQ(
x_dims[i] * expand_times[i], out_dims[i],
"Each dimension size of Input(Out@GRAD) should be "
"equal to multiplication of crroresponding dimension "
"size of Input(X) and Attr(expand_times) value.");
}
}
}
auto x_grad_name = framework::GradVarName("X");
......
......@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/fill_constant_op.h"
#include <string>
namespace paddle {
namespace operators {
......@@ -22,9 +22,22 @@ class FillConstantOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of FillConstantOp should not be null.");
PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
"Output(Out) of FillConstantOp should not be null.");
auto& shape = ctx->Attrs().Get<std::vector<int64_t>>("shape");
if (shape.empty() && ctx->HasInput("ShapeTensor")) {
auto shape_dims = ctx->GetInputDim("ShapeTensor");
int num_ele = 1;
for (int i = 0; i < shape_dims.size(); ++i) {
num_ele *= shape_dims[i];
}
auto vec_dims = std::vector<int>(num_ele, -1);
ctx->SetOutputDim("Out", framework::make_ddim(vec_dims));
return;
}
ctx->SetOutputDim("Out", framework::make_ddim(shape));
}
......@@ -35,6 +48,16 @@ class FillConstantOp : public framework::OperatorWithKernel {
framework::proto::VarType::Type(ctx.Attr<int>("dtype")),
ctx.GetPlace());
}
framework::OpKernelType GetKernelTypeForVar(
const std::string& var_name, const Tensor& tensor,
const framework::OpKernelType& expected_kernel_type) const override {
if (var_name == "ShapeTensor" || var_name == "ShapeTensorList") {
return expected_kernel_type;
}
return framework::OpKernelType(expected_kernel_type.data_type_,
tensor.place(), tensor.layout());
}
};
class FillConstantOpVarTypeInference : public framework::VarTypeInference {
......@@ -55,7 +78,18 @@ class FillConstantOpMaker : public framework::OpProtoAndCheckerMaker {
"Output data type")
.SetDefault(framework::proto::VarType::FP32);
AddAttr<std::vector<int64_t>>("shape",
"(vector<int64_t>) The shape of the output");
"(vector<int64_t>) The shape of the output")
.SetDefault({});
AddInput("ShapeTensor",
"(Tensor<int>), optional). The shape of the output."
"It has a higher priority than Attr(shape).")
.AsDispensable();
AddInput("ShapeTensorList",
"(vector<Tensor<int>>, optional). The shape of the output. "
"It has a higher priority than Attr(shape)."
"The shape of the element in vector must be [1].")
.AsDuplicable()
.AsDispensable();
AddAttr<float>("value", "(float, default 0) The value to be filled")
.SetDefault(0.0f);
AddAttr<bool>("force_cpu",
......
......@@ -22,6 +22,53 @@ limitations under the License. */
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
inline framework::DDim GetShape(const framework::ExecutionContext &ctx) {
// 1. shape is a Tensor
if (ctx.HasInput("ShapeTensor")) {
auto *shape_tensor = ctx.Input<framework::LoDTensor>("ShapeTensor");
auto *shape_data = shape_tensor->data<int>();
framework::Tensor cpu_shape_tensor;
if (platform::is_gpu_place(shape_tensor->place())) {
TensorCopySync(*shape_tensor, platform::CPUPlace(), &cpu_shape_tensor);
shape_data = cpu_shape_tensor.data<int>();
}
auto vec_shape =
std::vector<int>(shape_data, shape_data + shape_tensor->numel());
return framework::make_ddim(vec_shape);
}
// 2. shape is a list/tuple containing Tensor
auto shape_tensor_list = ctx.MultiInput<framework::Tensor>("ShapeTensorList");
if (shape_tensor_list.size() > 0) {
std::vector<int> vec_shape;
for (size_t i = 0; i < shape_tensor_list.size(); ++i) {
auto tensor = shape_tensor_list[i];
PADDLE_ENFORCE_EQ(
tensor->dims(), framework::make_ddim({1}),
"ShapeError: If the element type of 'shape' in FillConstantOp is "
"Tensor, "
"the element's shape must be [1]. But received the element's shape "
"is [%s]",
tensor->dims());
if (platform::is_gpu_place(tensor->place())) {
framework::Tensor temp;
TensorCopySync(*tensor, platform::CPUPlace(), &temp);
vec_shape.push_back(*temp.data<int>());
} else {
vec_shape.push_back(*tensor->data<int>());
}
}
return framework::make_ddim(vec_shape);
}
// 3. shape is a list/tuple without containing Tensor
auto vec_shape = ctx.Attr<std::vector<int64_t>>("shape");
return framework::make_ddim(vec_shape);
}
template <typename T>
class FillConstantKernel : public framework::OpKernel<T> {
public:
......@@ -35,14 +82,14 @@ class FillConstantKernel : public framework::OpKernel<T> {
framework::Variable *out_var = ctx.OutputVar("Out");
auto shape = GetShape(ctx);
if (out_var->IsType<framework::LoDTensor>()) {
tensor = out_var->GetMutable<framework::LoDTensor>();
tensor->Resize(
framework::make_ddim(ctx.Attr<std::vector<int64_t>>("shape")));
tensor->Resize(shape);
} else if (out_var->IsType<framework::SelectedRows>()) {
tensor = out_var->GetMutable<framework::SelectedRows>()->mutable_value();
tensor->Resize(
framework::make_ddim(ctx.Attr<std::vector<int64_t>>("shape")));
tensor->Resize(shape);
} else {
PADDLE_THROW(
"fill constant op's output only"
......
......@@ -23,7 +23,7 @@ import logging
from .. import compat as cpt
from . import unique_name
from . import log_helper
import paddle.fluid
__all__ = [
'append_backward',
'gradients',
......@@ -1247,15 +1247,15 @@ def calc_gradient(targets, inputs, target_gradients=None, no_grad_set=None):
target = targets[i]
if grad is None:
grad_name = _append_grad_suffix_(target.name)
op_desc = _create_op_desc_("fill_constant_batch_size_like",
{"Input": [target.name]},
target_shape = paddle.fluid.layers.shape(target)
op_desc = _create_op_desc_("fill_constant",
{"ShapeTensor": [target_shape.name]},
{"Out": [grad_name]}, {
"shape": target.shape,
"shape": [],
"value": 1.0,
"dtype": target.dtype,
'input_dim_idx': 0,
'output_dim_idx': 0
})
block.desc.append_op().copy_from(op_desc)
input_grad_names_set.add(grad_name)
else:
......
......@@ -532,7 +532,6 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None):
data2 = fluid.layers.fill_constant(shape=[2,1], value=5, dtype='int64', out=data1)
#data1=[[5], [5]] data2=[[5], [5]]
"""
helper = LayerHelper("fill_constant", **locals())
if convert_dtype(dtype) not in [
'bool', 'float16', 'float32', 'float64', 'int32', 'int64'
......@@ -541,6 +540,56 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None):
"The create data type in fill_constant must be one of 'bool', float16, float32,"
"float64, int32 or int64, but received %s." % convert_dtype(
(dtype)))
if not isinstance(shape, (list, tuple, Variable)):
raise TypeError(
"The type of 'shape' in fill_constant must be Variable, list or tuple, but "
"received %s." % (type(shape)))
inputs = {}
attrs = {
'value': float(value),
'force_cpu': force_cpu or force_init_on_cpu()
}
def _contain_var(one_list):
for ele in one_list:
if isinstance(ele, Variable):
return True
return False
def _get_attr_shape(list_shape):
attr_shape = []
for idx, dim in enumerate(list_shape):
if isinstance(dim, Variable):
attr_shape.append(-1)
else:
attr_shape.append(dim)
return attr_shape
def _get_shape_tensor(list_shape):
new_shape_tensor = []
for dim in list_shape:
if isinstance(dim, Variable):
dim.stop_gradient = True
new_shape_tensor.append(dim)
else:
temp_out = helper.create_variable_for_type_inference('int32')
fill_constant([1], 'int32', dim, force_cpu=True, out=temp_out)
new_shape_tensor.append(temp_out)
return new_shape_tensor
if isinstance(shape, Variable):
shape.stop_gradient = True
inputs["ShapeTensor"] = shape
elif isinstance(shape, (list, tuple)):
assert len(shape) > 0, (
"The size of 'shape' in fill_constant can't be zero, "
"but received %s." % len(shape))
attrs["shape"] = _get_attr_shape(shape)
if _contain_var(shape):
inputs['ShapeTensorList'] = _get_shape_tensor(shape)
if out is None:
out = helper.create_variable_for_type_inference(dtype=dtype)
else:
......@@ -549,16 +598,12 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None):
"The create data type in op must be same with out type"
"but received %s and out dtype %s." % (convert_dtype(
(dtype), convert_dtype(out.dtype))))
attrs['dtype'] = out.dtype
helper.append_op(
type='fill_constant',
inputs={},
inputs=inputs,
outputs={'Out': [out]},
attrs={
'shape': shape,
'dtype': out.dtype,
'value': float(value),
'force_cpu': force_cpu or force_init_on_cpu()
},
attrs=attrs,
stop_gradient=True)
out.stop_gradient = True
return out
......
......@@ -222,6 +222,8 @@ class TestExpandAPI(OpTest):
out_2 = fluid.layers.expand(x, expand_times=[positive_2, 3])
out_3 = fluid.layers.expand(x, expand_times=expand_times)
g0 = fluid.backward.calc_gradient(out_2, x)
exe = fluid.Executor(place=fluid.CPUPlace())
res_1, res_2, res_3 = exe.run(fluid.default_main_program(),
feed={
......
......@@ -24,6 +24,7 @@ import paddle.fluid as fluid
from paddle.fluid import compiler, Program, program_guard
# Situation 1: Attr(shape) is a list(without tensor)
class TestFillConstantOp1(OpTest):
def setUp(self):
'''Test fill_constant op with specified value
......@@ -106,10 +107,121 @@ class TestFillConstantOpWithSelectedRows(OpTest):
self.check_with_place(place)
# Situation 2: Attr(shape) is a list(with tensor)
class TestFillConstantOp1_ShapeTensorList(OpTest):
def setUp(self):
'''Test fill_constant op with specified value
'''
self.op_type = "fill_constant"
self.init_data()
shape_tensor_list = []
for index, ele in enumerate(self.shape):
shape_tensor_list.append(("x" + str(index), np.ones(
(1)).astype('int32') * ele))
self.inputs = {"ShapeTensorList": shape_tensor_list}
self.attrs = {'shape': self.infer_shape, 'value': self.value}
self.outputs = {'Out': np.full(self.shape, self.value)}
def init_data(self):
self.shape = [123, 92]
self.infer_shape = [-1, 92]
self.value = 3.8
def test_check_output(self):
self.check_output()
class TestFillConstantOp2_ShapeTensorList(OpTest):
def setUp(self):
'''Test fill_constant op with default value
'''
self.op_type = "fill_constant"
self.init_data()
shape_tensor_list = []
for index, ele in enumerate(self.shape):
shape_tensor_list.append(("x" + str(index), np.ones(
(1)).astype('int32') * ele))
self.inputs = {"ShapeTensorList": shape_tensor_list}
self.attrs = {'shape': self.infer_shape}
self.outputs = {'Out': np.full(self.shape, 0.0)}
def init_data(self):
self.shape = [123, 92]
self.infer_shape = [-1, -1]
def test_check_output(self):
self.check_output()
class TestFillConstantOp3_ShapeTensorList(TestFillConstantOp1_ShapeTensorList):
def init_data(self):
self.shape = [123, 92]
self.infer_shape = [123, -1]
self.value = 10000000000
class TestFillConstantOp4_ShapeTensorList(TestFillConstantOp1_ShapeTensorList):
def init_data(self):
self.shape = [123, 92]
self.infer_shape = [123, -1]
self.value = 3
# Situation 3: shape is a tensor
class TestFillConstantOp1_ShapeTensor(OpTest):
def setUp(self):
'''Test fill_constant op with specified value
'''
self.op_type = "fill_constant"
self.init_data()
self.inputs = {"ShapeTensor": np.array(self.shape).astype("int32")}
self.attrs = {'value': self.value}
self.outputs = {'Out': np.full(self.shape, self.value)}
def init_data(self):
self.shape = [123, 92]
self.value = 3.8
def test_check_output(self):
self.check_output()
# # Test python API
class TestFillConstantAPI(OpTest):
def test_api(self):
positive_2 = fluid.layers.fill_constant([1], "int32", 2)
shape_tensor = fluid.layers.data(
name="shape_tensor",
shape=[2],
append_batch_size=False,
dtype="int32")
out_1 = fluid.layers.fill_constant(
shape=[1, 2], dtype="float32", value=1.1)
out_2 = fluid.layers.fill_constant(
shape=[1, positive_2], dtype="float32", value=1.1)
out_3 = fluid.layers.fill_constant(
shape=shape_tensor, dtype="float32", value=1.1)
exe = fluid.Executor(place=fluid.CPUPlace())
res_1, res_2, res_3 = exe.run(
fluid.default_main_program(),
feed={"shape_tensor": np.array([1, 2]).astype("int32")},
fetch_list=[out_1, out_2, out_3])
assert np.array_equal(res_1, np.full([1, 2], 1.1, dtype="float32"))
assert np.array_equal(res_2, np.full([1, 2], 1.1, dtype="float32"))
assert np.array_equal(res_3, np.full([1, 2], 1.1, dtype="float32"))
class TestFillConstantOpError(OpTest):
def test_errors(self):
with program_guard(Program(), Program()):
#for ci coverage
#for ci coverage
x1 = fluid.layers.data(name='x1', shape=[1], dtype="int16")
self.assertRaises(
ValueError,
......@@ -124,9 +236,10 @@ class TestFillConstantOpError(OpTest):
value=5,
dtype='int16',
out=x1)
# The input dtype of fill_constant must be one of bool, float16,
# The input dtype of fill_constant must be one of bool, float16,
#float32, float64, int32 or int64
x2 = fluid.layers.data(name='x2', shape=[1], dtype="int32")
self.assertRaises(
TypeError,
fluid.layers.fill_constant,
......@@ -141,6 +254,17 @@ class TestFillConstantOpError(OpTest):
dtype='float64',
out=x2)
# test Error of Shape
def test_shape_type():
fluid.layers.fill_constant(shape=1, dtype="float32", value=1)
self.assertRaises(TypeError, test_shape_type)
def test_shape_size():
fluid.layers.fill_constant(shape=[], dtype="float32", value=1)
self.assertRaises(AssertionError, test_shape_size)
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册