未验证 提交 1288ac29 编写于 作者: W wangchaochaohu 提交者: GitHub

fix expand bug (#20340)

* fix expand bug test=develop

* fix style test=develop

* fix style test=develop

* fix style test=develop

* fix style test=develop
上级 98ec9927
...@@ -226,8 +226,11 @@ REGISTER_OP_CPU_KERNEL( ...@@ -226,8 +226,11 @@ REGISTER_OP_CPU_KERNEL(
expand, ops::ExpandKernel<paddle::platform::CPUDeviceContext, float>, expand, ops::ExpandKernel<paddle::platform::CPUDeviceContext, float>,
ops::ExpandKernel<paddle::platform::CPUDeviceContext, double>, ops::ExpandKernel<paddle::platform::CPUDeviceContext, double>,
ops::ExpandKernel<paddle::platform::CPUDeviceContext, int>, ops::ExpandKernel<paddle::platform::CPUDeviceContext, int>,
ops::ExpandKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::ExpandKernel<paddle::platform::CPUDeviceContext, bool>); ops::ExpandKernel<paddle::platform::CPUDeviceContext, bool>);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
expand_grad, expand_grad,
ops::ExpandGradKernel<paddle::platform::CPUDeviceContext, float>, ops::ExpandGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::ExpandGradKernel<paddle::platform::CPUDeviceContext, double>); ops::ExpandGradKernel<paddle::platform::CPUDeviceContext, double>,
ops::ExpandGradKernel<paddle::platform::CPUDeviceContext, int>,
ops::ExpandGradKernel<paddle::platform::CPUDeviceContext, int64_t>);
...@@ -18,8 +18,11 @@ REGISTER_OP_CUDA_KERNEL( ...@@ -18,8 +18,11 @@ REGISTER_OP_CUDA_KERNEL(
expand, ops::ExpandKernel<paddle::platform::CUDADeviceContext, float>, expand, ops::ExpandKernel<paddle::platform::CUDADeviceContext, float>,
ops::ExpandKernel<paddle::platform::CUDADeviceContext, double>, ops::ExpandKernel<paddle::platform::CUDADeviceContext, double>,
ops::ExpandKernel<paddle::platform::CUDADeviceContext, int>, ops::ExpandKernel<paddle::platform::CUDADeviceContext, int>,
ops::ExpandKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::ExpandKernel<paddle::platform::CUDADeviceContext, bool>); ops::ExpandKernel<paddle::platform::CUDADeviceContext, bool>);
REGISTER_OP_CUDA_KERNEL( REGISTER_OP_CUDA_KERNEL(
expand_grad, expand_grad,
ops::ExpandGradKernel<paddle::platform::CUDADeviceContext, float>, ops::ExpandGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::ExpandGradKernel<paddle::platform::CUDADeviceContext, double>); ops::ExpandGradKernel<paddle::platform::CUDADeviceContext, double>,
ops::ExpandGradKernel<paddle::platform::CUDADeviceContext, int>,
ops::ExpandGradKernel<paddle::platform::CUDADeviceContext, int64_t>);
...@@ -12077,10 +12077,21 @@ def expand(x, expand_times, name=None): ...@@ -12077,10 +12077,21 @@ def expand(x, expand_times, name=None):
expanded_2 = fluid.layers.expand(data_2, expand_times=expand_times) expanded_2 = fluid.layers.expand(data_2, expand_times=expand_times)
# the shape of expanded_2 is [48, 56]. # the shape of expanded_2 is [48, 56].
""" """
if not isinstance(x, Variable):
raise TypeError(
"The type of 'input' in reduce_sum must be Variable, but received %s"
% (type(x)))
if not isinstance(expand_times, (list, tuple, Variable)): if not isinstance(expand_times, (list, tuple, Variable)):
raise ValueError( raise ValueError(
"Input expand_times must be an Variable, python list or tuple.") "Input expand_times must be an Variable, python list or tuple.")
if convert_dtype(
x.dtype) not in ['bool', 'float32', 'float64', 'int32', 'int64']:
raise TypeError(
"The data type of input in expand must be one of bool float32, float64, int32 or int64, but received %s."
% (convert_dtype(x.dtype)))
if convert_dtype(x.dtype) == 'bool' and x.stop_gradient == True:
raise ValueError(
"expand op bool date type must set the stop_gradient to be False")
helper = LayerHelper('expand', input=x, **locals()) helper = LayerHelper('expand', input=x, **locals())
inputs = {"X": x} inputs = {"X": x}
......
...@@ -18,6 +18,7 @@ import unittest ...@@ -18,6 +18,7 @@ import unittest
import numpy as np import numpy as np
from op_test import OpTest from op_test import OpTest
import paddle.fluid as fluid import paddle.fluid as fluid
from paddle.fluid import compiler, Program, program_guard
# Situation 1: expand_times is a list(without tensor) # Situation 1: expand_times is a list(without tensor)
...@@ -176,6 +177,36 @@ class TestExpandOpBoolean(OpTest): ...@@ -176,6 +177,36 @@ class TestExpandOpBoolean(OpTest):
self.check_output() self.check_output()
# Situation 56: input x is Integer
class TestExpandOpInt64_t(OpTest):
def setUp(self):
self.op_type = "expand"
self.inputs = {
'X': np.random.randint(
10, size=(2, 4, 5)).astype("int64")
}
self.attrs = {'expand_times': [2, 1, 4]}
output = np.tile(self.inputs['X'], (2, 1, 4))
self.outputs = {'Out': output}
def test_check_output(self):
self.check_output()
class TestExpandError(OpTest):
def test_errors(self):
with program_guard(Program(), Program()):
x1 = fluid.create_lod_tensor(
np.array([[-1]]), [[1]], fluid.CPUPlace())
expand_times = [2, 2]
self.assertRaises(TypeError, fluid.layers.expand, x1, expand_times)
x2 = fluid.layers.data(name='x2', shape=[4], dtype="uint8")
self.assertRaises(TypeError, fluid.layers.expand, x2, expand_times)
x3 = fluid.layers.data(name='x3', shape=[4], dtype="bool")
x3.stop_gradient = True
self.assertRaises(ValueError, fluid.layers.expand, x3, expand_times)
# Test python API # Test python API
class TestExpandAPI(OpTest): class TestExpandAPI(OpTest):
def test_api(self): def test_api(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册