未验证 提交 a497123e 编写于 作者: W wangchaochaohu 提交者: GitHub

refine linspace Op for dtype setting(#27071) (#27086)

上级 1d1bebf0
...@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/framework/data_type_transform.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/linspace_op.h" #include "paddle/fluid/operators/linspace_op.h"
#include "paddle/fluid/platform/cuda_primitives.h" #include "paddle/fluid/platform/cuda_primitives.h"
...@@ -19,6 +20,8 @@ limitations under the License. */ ...@@ -19,6 +20,8 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace operators { namespace operators {
using Tensor = framework::Tensor;
template <typename T> template <typename T>
__global__ void LinspaceKernel(T start, double step, int64_t size, T* out) { __global__ void LinspaceKernel(T start, double step, int64_t size, T* out) {
CUDA_KERNEL_LOOP(index, size) { CUDA_KERNEL_LOOP(index, size) {
...@@ -35,15 +38,27 @@ template <typename T> ...@@ -35,15 +38,27 @@ template <typename T>
class CUDALinspaceKernel : public framework::OpKernel<T> { class CUDALinspaceKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
auto* start_t = context.Input<framework::Tensor>("Start"); auto* pre_start = context.Input<framework::Tensor>("Start");
auto* stop_t = context.Input<framework::Tensor>("Stop"); auto* pre_stop = context.Input<framework::Tensor>("Stop");
auto* num_t = context.Input<framework::Tensor>("Num"); auto* num_t = context.Input<framework::Tensor>("Num");
auto* out = context.Output<framework::Tensor>("Out"); auto* out = context.Output<framework::Tensor>("Out");
auto dtype = static_cast<framework::proto::VarType::Type>(
context.Attr<int>("dtype"));
Tensor start_t;
Tensor stop_t;
auto start_dtype =
framework::OpKernelType(pre_start->type(), context.GetPlace());
auto stop_dtype =
framework::OpKernelType(pre_stop->type(), context.GetPlace());
auto out_dtype = framework::OpKernelType(dtype, context.GetPlace());
framework::TransDataType(start_dtype, out_dtype, *pre_start, &start_t);
framework::TransDataType(stop_dtype, out_dtype, *pre_stop, &stop_t);
framework::Tensor n; framework::Tensor n;
framework::TensorCopy(*start_t, platform::CPUPlace(), &n); framework::TensorCopy(start_t, platform::CPUPlace(), &n);
T start = n.data<T>()[0]; T start = n.data<T>()[0];
framework::TensorCopy(*stop_t, platform::CPUPlace(), &n); framework::TensorCopy(stop_t, platform::CPUPlace(), &n);
T stop = n.data<T>()[0]; T stop = n.data<T>()[0];
framework::TensorCopy(*num_t, platform::CPUPlace(), &n); framework::TensorCopy(*num_t, platform::CPUPlace(), &n);
int32_t num = n.data<int32_t>()[0]; int32_t num = n.data<int32_t>()[0];
......
...@@ -14,20 +14,38 @@ limitations under the License. */ ...@@ -14,20 +14,38 @@ limitations under the License. */
#pragma once #pragma once
#include <functional> #include <functional>
#include "paddle/fluid/framework/data_type_transform.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/operators/math/math_function.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
using Tensor = framework::Tensor;
template <typename T> template <typename T>
class CPULinspaceKernel : public framework::OpKernel<T> { class CPULinspaceKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
T start = context.Input<framework::Tensor>("Start")->data<T>()[0]; auto* pre_start = context.Input<framework::Tensor>("Start");
T stop = context.Input<framework::Tensor>("Stop")->data<T>()[0]; auto* pre_stop = context.Input<framework::Tensor>("Stop");
int32_t num = context.Input<framework::Tensor>("Num")->data<int32_t>()[0]; int32_t num = context.Input<framework::Tensor>("Num")->data<int32_t>()[0];
auto* out = context.Output<framework::Tensor>("Out"); auto* out = context.Output<framework::Tensor>("Out");
auto dtype = static_cast<framework::proto::VarType::Type>(
context.Attr<int>("dtype"));
Tensor start_t;
Tensor stop_t;
auto start_dtype =
framework::OpKernelType(pre_start->type(), context.GetPlace());
auto stop_dtype =
framework::OpKernelType(pre_stop->type(), context.GetPlace());
auto out_dtype = framework::OpKernelType(dtype, context.GetPlace());
framework::TransDataType(start_dtype, out_dtype, *pre_start, &start_t);
framework::TransDataType(stop_dtype, out_dtype, *pre_stop, &stop_t);
T start = start_t.data<T>()[0];
T stop = stop_t.data<T>()[0];
PADDLE_ENFORCE(num > 0, "The num of linspace op should be larger than 0."); PADDLE_ENFORCE(num > 0, "The num of linspace op should be larger than 0.");
out->Resize(framework::make_ddim({num})); out->Resize(framework::make_ddim({num}));
......
...@@ -1477,19 +1477,32 @@ def linspace(start, stop, num, dtype=None, name=None): ...@@ -1477,19 +1477,32 @@ def linspace(start, stop, num, dtype=None, name=None):
helper = LayerHelper("linspace", **locals()) helper = LayerHelper("linspace", **locals())
start_dtype = convert_dtype(tensor_start.dtype)
stop_dtype = convert_dtype(tensor_stop.dtype)
out_dtype = convert_dtype(dtype)
if isinstance(start, Variable): if isinstance(start, Variable):
check_dtype(start.dtype, 'start', (convert_dtype(dtype)), 'linspace') check_dtype(start.dtype, 'start',
['float32', 'float64', 'int32', 'int64'], 'linspace')
else: else:
check_type(start, 'start', (int, float), 'linspace') check_type(start, 'start', (int, float), 'linspace')
if isinstance(stop, Variable): if isinstance(stop, Variable):
check_dtype(stop.dtype, 'stop', (convert_dtype(dtype)), 'linspace') check_dtype(stop.dtype, 'stop',
['float32', 'float64', 'int32', 'int64'], 'linspace')
else: else:
check_type(stop, 'stop', (int, float), 'linspace') check_type(stop, 'stop', (int, float), 'linspace')
if isinstance(num, Variable): if isinstance(num, Variable):
check_dtype(num.dtype, 'num', ['int32'], 'linspace') check_dtype(num.dtype, 'num', ['int32'], 'linspace')
check_dtype(dtype, 'dtype', ['int32', 'int64', 'float32', 'float64'], check_dtype(dtype, 'dtype', ['int32', 'int64', 'float32', 'float64'],
'linspace') 'linspace')
if ((stop_dtype == "float64" or start_dtype == "float64") and
out_dtype in ["float32", "int32"]) or ((stop_dtype == "int64" or
start_dtype == "int64") and
out_dtype == "int32"):
raise ValueError(
"The dtype of start/stop is {}/{} but the attr(dtype) of linspace is {}, "
"which may cause data type overflows. Please reset attr(dtype) of linspace."
.format(start_dtype, stop_dtype, dtype))
out = helper.create_variable_for_type_inference(dtype=dtype) out = helper.create_variable_for_type_inference(dtype=dtype)
......
...@@ -154,16 +154,16 @@ class TestLinspaceOpError(unittest.TestCase): ...@@ -154,16 +154,16 @@ class TestLinspaceOpError(unittest.TestCase):
self.assertRaises(TypeError, test_step_dtype) self.assertRaises(TypeError, test_step_dtype)
def test_start_dtype(): def test_start_dtype():
start = fluid.data(shape=[1], dtype="int32", name="start") start = fluid.data(shape=[1], dtype="float64", name="start")
fluid.layers.linspace(start, 10, 1, dtype="float32") fluid.layers.linspace(start, 10, 1, dtype="float32")
self.assertRaises(TypeError, test_start_dtype) self.assertRaises(ValueError, test_start_dtype)
def test_end_dtype(): def test_end_dtype():
end = fluid.data(shape=[1], dtype="int32", name="end") end = fluid.data(shape=[1], dtype="float64", name="end")
fluid.layers.linspace(0, end, 1, dtype="float32") fluid.layers.linspace(0, end, 1, dtype="float32")
self.assertRaises(TypeError, test_end_dtype) self.assertRaises(ValueError, test_end_dtype)
def test_num_dtype(): def test_num_dtype():
num = fluid.data(shape=[1], dtype="int32", name="step") num = fluid.data(shape=[1], dtype="int32", name="step")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册