diff --git a/paddle/fluid/operators/linspace_op.cu b/paddle/fluid/operators/linspace_op.cu
index 8aca892a81d41b1e0a9f7f9c14169c2817ae9452..793253b6b8894de8d89b301921383ebfd53d66fc 100644
--- a/paddle/fluid/operators/linspace_op.cu
+++ b/paddle/fluid/operators/linspace_op.cu
@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
+#include "paddle/fluid/framework/data_type_transform.h"
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/operators/linspace_op.h"
 #include "paddle/fluid/platform/cuda_primitives.h"
@@ -19,6 +20,8 @@ limitations under the License. */
 namespace paddle {
 namespace operators {
 
+using Tensor = framework::Tensor;
+
 template <typename T>
 __global__ void LinspaceKernel(T start, double step, int64_t size, T* out) {
   CUDA_KERNEL_LOOP(index, size) {
@@ -35,15 +38,27 @@ template <typename T>
 class CUDALinspaceKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
-    auto* start_t = context.Input<framework::Tensor>("Start");
-    auto* stop_t = context.Input<framework::Tensor>("Stop");
+    auto* pre_start = context.Input<framework::Tensor>("Start");
+    auto* pre_stop = context.Input<framework::Tensor>("Stop");
     auto* num_t = context.Input<framework::Tensor>("Num");
     auto* out = context.Output<framework::Tensor>("Out");
+    auto dtype = static_cast<framework::proto::VarType::Type>(
+        context.Attr<int>("dtype"));
+
+    Tensor start_t;
+    Tensor stop_t;
+    auto start_dtype =
+        framework::OpKernelType(pre_start->type(), context.GetPlace());
+    auto stop_dtype =
+        framework::OpKernelType(pre_stop->type(), context.GetPlace());
+    auto out_dtype = framework::OpKernelType(dtype, context.GetPlace());
+    framework::TransDataType(start_dtype, out_dtype, *pre_start, &start_t);
+    framework::TransDataType(stop_dtype, out_dtype, *pre_stop, &stop_t);
 
     framework::Tensor n;
-    framework::TensorCopy(*start_t, platform::CPUPlace(), &n);
+    framework::TensorCopy(start_t, platform::CPUPlace(), &n);
     T start = n.data<T>()[0];
-    framework::TensorCopy(*stop_t, platform::CPUPlace(), &n);
+    framework::TensorCopy(stop_t, platform::CPUPlace(), &n);
     T stop = n.data<T>()[0];
     framework::TensorCopy(*num_t, platform::CPUPlace(), &n);
     int32_t num = n.data<int32_t>()[0];
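Note on the kernel changes (the same pattern appears in the CPU kernel below): before this patch, `Start`/`Stop` were read through `data<T>()`, where `T` is the *output* dtype, so an input tensor holding a different dtype had its raw bytes reinterpreted. The patch now casts both inputs to the output dtype with `framework::TransDataType` before reading them. A minimal NumPy sketch of the difference, for illustration only (NumPy is not part of the patch):

```python
import numpy as np

start = np.array([10.0], dtype=np.float64)

# Reinterpreting the float64 bytes as float32 -- the effect of reading a
# mismatched tensor via data<T>() -- yields garbage rather than 10.0:
print(start.view(np.float32)[0])    # 0.0 on little-endian machines

# Converting the value first -- what TransDataType does -- preserves it:
print(start.astype(np.float32)[0])  # 10.0
```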
diff --git a/paddle/fluid/operators/linspace_op.h b/paddle/fluid/operators/linspace_op.h
index 9fb4960375ed7be60598d558c65310bd4a4b84bc..898f611f864dc8bfac2ba7e41b91f5f5bbe524ab 100644
--- a/paddle/fluid/operators/linspace_op.h
+++ b/paddle/fluid/operators/linspace_op.h
@@ -14,20 +14,38 @@ limitations under the License. */
 #pragma once
 #include <functional>
+#include "paddle/fluid/framework/data_type_transform.h"
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/operators/math/math_function.h"
 
 namespace paddle {
 namespace operators {
 
+using Tensor = framework::Tensor;
+
 template <typename T>
 class CPULinspaceKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
-    T start = context.Input<framework::Tensor>("Start")->data<T>()[0];
-    T stop = context.Input<framework::Tensor>("Stop")->data<T>()[0];
+    auto* pre_start = context.Input<framework::Tensor>("Start");
+    auto* pre_stop = context.Input<framework::Tensor>("Stop");
     int32_t num = context.Input<framework::Tensor>("Num")->data<int32_t>()[0];
     auto* out = context.Output<framework::Tensor>("Out");
+    auto dtype = static_cast<framework::proto::VarType::Type>(
+        context.Attr<int>("dtype"));
+
+    Tensor start_t;
+    Tensor stop_t;
+    auto start_dtype =
+        framework::OpKernelType(pre_start->type(), context.GetPlace());
+    auto stop_dtype =
+        framework::OpKernelType(pre_stop->type(), context.GetPlace());
+    auto out_dtype = framework::OpKernelType(dtype, context.GetPlace());
+    framework::TransDataType(start_dtype, out_dtype, *pre_start, &start_t);
+    framework::TransDataType(stop_dtype, out_dtype, *pre_stop, &stop_t);
+
+    T start = start_t.data<T>()[0];
+    T stop = stop_t.data<T>()[0];
     PADDLE_ENFORCE(num > 0, "The num of linspace op should be larger than 0.");
 
     out->Resize(framework::make_ddim({num}));
diff --git a/python/paddle/fluid/layers/tensor.py b/python/paddle/fluid/layers/tensor.py
index 75b2bf849ccfacf1ff92707f06f8526517d73037..a90551c1b7b4fd45ae9a0e1cfa225a87db811295 100644
--- a/python/paddle/fluid/layers/tensor.py
+++ b/python/paddle/fluid/layers/tensor.py
@@ -1462,19 +1462,32 @@ def linspace(start, stop, num, dtype=None, name=None):
     helper = LayerHelper("linspace", **locals())
 
+    start_dtype = convert_dtype(tensor_start.dtype)
+    stop_dtype = convert_dtype(tensor_stop.dtype)
+    out_dtype = convert_dtype(dtype)
     if isinstance(start, Variable):
-        check_dtype(start.dtype, 'start', (convert_dtype(dtype)), 'linspace')
+        check_dtype(start.dtype, 'start',
+                    ['float32', 'float64', 'int32', 'int64'], 'linspace')
     else:
         check_type(start, 'start', (int, float), 'linspace')
 
     if isinstance(stop, Variable):
-        check_dtype(stop.dtype, 'stop', (convert_dtype(dtype)), 'linspace')
+        check_dtype(stop.dtype, 'stop',
+                    ['float32', 'float64', 'int32', 'int64'], 'linspace')
     else:
         check_type(stop, 'stop', (int, float), 'linspace')
     if isinstance(num, Variable):
         check_dtype(num.dtype, 'num', ['int32'], 'linspace')
     check_dtype(dtype, 'dtype', ['int32', 'int64', 'float32', 'float64'],
                 'linspace')
+    if ((stop_dtype == "float64" or start_dtype == "float64") and
+            out_dtype in ["float32", "int32"]) or ((stop_dtype == "int64" or
+                                                    start_dtype == "int64") and
+                                                   out_dtype == "int32"):
+        raise ValueError(
+            "The dtype of start/stop is {}/{} but the attr(dtype) of linspace is {}, "
+            "which may cause data type overflows. Please reset attr(dtype) of linspace."
+            .format(start_dtype, stop_dtype, dtype))
 
     out = helper.create_variable_for_type_inference(dtype=dtype)
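With the kernels able to cast, the Python layer now accepts any of `float32`/`float64`/`int32`/`int64` for a `Variable` start/stop instead of requiring an exact match with `dtype`, and rejects only the narrowing combinations up front with a `ValueError`. A hedged restatement of the inlined condition, using the hypothetical helper name `is_narrowing` (the patch inlines this logic rather than defining a helper):

```python
def is_narrowing(start_dtype, stop_dtype, out_dtype):
    """True when casting start/stop to out_dtype may lose information."""
    has_f64 = "float64" in (start_dtype, stop_dtype)
    has_i64 = "int64" in (start_dtype, stop_dtype)
    return ((has_f64 and out_dtype in ("float32", "int32")) or
            (has_i64 and out_dtype == "int32"))

assert is_narrowing("float64", "float32", "float32")  # now a ValueError
assert not is_narrowing("int32", "int32", "int64")    # widening is allowed
```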
diff --git a/python/paddle/fluid/tests/unittests/test_linspace.py b/python/paddle/fluid/tests/unittests/test_linspace.py
index 6d1f42111eebff0f469317ddf2a9ec7698a7ae1e..03cb84ec99e0259a33a086c3d3e5a71abea09d2b 100644
--- a/python/paddle/fluid/tests/unittests/test_linspace.py
+++ b/python/paddle/fluid/tests/unittests/test_linspace.py
@@ -154,16 +154,16 @@ class TestLinspaceOpError(unittest.TestCase):
         self.assertRaises(TypeError, test_step_dtype)
 
         def test_start_dtype():
-            start = fluid.data(shape=[1], dtype="int32", name="start")
+            start = fluid.data(shape=[1], dtype="float64", name="start")
             fluid.layers.linspace(start, 10, 1, dtype="float32")
 
-        self.assertRaises(TypeError, test_start_dtype)
+        self.assertRaises(ValueError, test_start_dtype)
 
         def test_end_dtype():
-            end = fluid.data(shape=[1], dtype="int32", name="end")
+            end = fluid.data(shape=[1], dtype="float64", name="end")
             fluid.layers.linspace(0, end, 1, dtype="float32")
 
-        self.assertRaises(TypeError, test_end_dtype)
+        self.assertRaises(ValueError, test_end_dtype)
 
         def test_num_dtype():
             num = fluid.data(shape=[1], dtype="int32", name="step")
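The updated tests mirror the new contract: mismatched but widening dtypes no longer trip `check_dtype`'s `TypeError`, while narrowing ones raise the new `ValueError`. A usage sketch inferred from the diff and its tests, not taken from the patch itself:

```python
import paddle.fluid as fluid

# Widening is now accepted: int32 start with a float64 output dtype.
start = fluid.data(shape=[1], dtype="int32", name="start")
out = fluid.layers.linspace(start, 10, 5, dtype="float64")

# Narrowing fails fast at graph-construction time with a ValueError.
start64 = fluid.data(shape=[1], dtype="float64", name="start64")
try:
    fluid.layers.linspace(start64, 10, 5, dtype="float32")
except ValueError as e:
    print(e)  # names the start/stop dtypes and the attr(dtype) of linspace
```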