Unverified commit 041f4ab8, authored by wangchaochaohu, committed by GitHub

refine linspace Op for dtype setting (#27071)

Parent: 92530ca4
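In short, the patch makes both linspace kernels cast the Start/Stop input tensors to the output dtype given by attr(dtype), and adds a Python-side guard that rejects casts likely to lose data. A minimal usage sketch of the resulting behavior (illustrative only; it uses the fluid.data / fluid.layers.linspace API exercised by the updated tests at the bottom of this diff):

import paddle.fluid as fluid

with fluid.program_guard(fluid.Program()):
    # Start may now differ from attr(dtype); the kernel casts it (int32 -> float32 here).
    start = fluid.data(shape=[1], dtype="int32", name="start")
    out = fluid.layers.linspace(start, 10, 5, dtype="float32")

    # Narrowing combinations, e.g. a float64 start with dtype="float32",
    # now raise ValueError instead of silently overflowing (see the tests below).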
@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
+#include "paddle/fluid/framework/data_type_transform.h"
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/operators/linspace_op.h"
 #include "paddle/fluid/platform/cuda_primitives.h"
@@ -19,6 +20,8 @@ limitations under the License. */
 namespace paddle {
 namespace operators {
 
+using Tensor = framework::Tensor;
+
 template <typename T>
 __global__ void LinspaceKernel(T start, double step, int64_t size, T* out) {
   CUDA_KERNEL_LOOP(index, size) {
@@ -35,15 +38,27 @@ template <typename T>
 class CUDALinspaceKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
-    auto* start_t = context.Input<framework::Tensor>("Start");
-    auto* stop_t = context.Input<framework::Tensor>("Stop");
+    auto* pre_start = context.Input<framework::Tensor>("Start");
+    auto* pre_stop = context.Input<framework::Tensor>("Stop");
     auto* num_t = context.Input<framework::Tensor>("Num");
     auto* out = context.Output<framework::Tensor>("Out");
+    auto dtype = static_cast<framework::proto::VarType::Type>(
+        context.Attr<int>("dtype"));
+    Tensor start_t;
+    Tensor stop_t;
+    auto start_dtype =
+        framework::OpKernelType(pre_start->type(), context.GetPlace());
+    auto stop_dtype =
+        framework::OpKernelType(pre_stop->type(), context.GetPlace());
+    auto out_dtype = framework::OpKernelType(dtype, context.GetPlace());
+    framework::TransDataType(start_dtype, out_dtype, *pre_start, &start_t);
+    framework::TransDataType(stop_dtype, out_dtype, *pre_stop, &stop_t);
+
     framework::Tensor n;
-    framework::TensorCopy(*start_t, platform::CPUPlace(), &n);
+    framework::TensorCopy(start_t, platform::CPUPlace(), &n);
     T start = n.data<T>()[0];
-    framework::TensorCopy(*stop_t, platform::CPUPlace(), &n);
+    framework::TensorCopy(stop_t, platform::CPUPlace(), &n);
     T stop = n.data<T>()[0];
     framework::TensorCopy(*num_t, platform::CPUPlace(), &n);
     int32_t num = n.data<int32_t>()[0];
...
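Note: conceptually, both the CUDA kernel above and the CPU kernel below now cast Start/Stop to attr(dtype) before generating the sequence. A NumPy reference sketch of that semantics (linspace_ref is illustrative, not part of the patch):

import numpy as np

def linspace_ref(start, stop, num, dtype):
    # Mirror the TransDataType calls above: cast the endpoints to the
    # output dtype first, then generate the sequence in that dtype.
    start = np.asarray(start).astype(dtype)
    stop = np.asarray(stop).astype(dtype)
    return np.linspace(start, stop, num, dtype=dtype)

print(linspace_ref(0, 10, 5, "float32"))  # [ 0.   2.5  5.   7.5 10. ]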
@@ -14,20 +14,38 @@ limitations under the License. */
 #pragma once
 #include <functional>
+#include "paddle/fluid/framework/data_type_transform.h"
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/operators/math/math_function.h"
 
 namespace paddle {
 namespace operators {
 
+using Tensor = framework::Tensor;
+
 template <typename T>
 class CPULinspaceKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
-    T start = context.Input<framework::Tensor>("Start")->data<T>()[0];
-    T stop = context.Input<framework::Tensor>("Stop")->data<T>()[0];
+    auto* pre_start = context.Input<framework::Tensor>("Start");
+    auto* pre_stop = context.Input<framework::Tensor>("Stop");
     int32_t num = context.Input<framework::Tensor>("Num")->data<int32_t>()[0];
     auto* out = context.Output<framework::Tensor>("Out");
+    auto dtype = static_cast<framework::proto::VarType::Type>(
+        context.Attr<int>("dtype"));
+
+    Tensor start_t;
+    Tensor stop_t;
+    auto start_dtype =
+        framework::OpKernelType(pre_start->type(), context.GetPlace());
+    auto stop_dtype =
+        framework::OpKernelType(pre_stop->type(), context.GetPlace());
+    auto out_dtype = framework::OpKernelType(dtype, context.GetPlace());
+    framework::TransDataType(start_dtype, out_dtype, *pre_start, &start_t);
+    framework::TransDataType(stop_dtype, out_dtype, *pre_stop, &stop_t);
+    T start = start_t.data<T>()[0];
+    T stop = stop_t.data<T>()[0];
+
     PADDLE_ENFORCE(num > 0, "The num of linspace op should be larger than 0.");
     out->Resize(framework::make_ddim({num}));
...
@@ -1462,19 +1462,32 @@ def linspace(start, stop, num, dtype=None, name=None):
     helper = LayerHelper("linspace", **locals())
 
+    start_dtype = convert_dtype(tensor_start.dtype)
+    stop_dtype = convert_dtype(tensor_stop.dtype)
+    out_dtype = convert_dtype(dtype)
     if isinstance(start, Variable):
-        check_dtype(start.dtype, 'start', (convert_dtype(dtype)), 'linspace')
+        check_dtype(start.dtype, 'start',
+                    ['float32', 'float64', 'int32', 'int64'], 'linspace')
     else:
         check_type(start, 'start', (int, float), 'linspace')
 
     if isinstance(stop, Variable):
-        check_dtype(stop.dtype, 'stop', (convert_dtype(dtype)), 'linspace')
+        check_dtype(stop.dtype, 'stop',
+                    ['float32', 'float64', 'int32', 'int64'], 'linspace')
     else:
         check_type(stop, 'stop', (int, float), 'linspace')
 
     if isinstance(num, Variable):
         check_dtype(num.dtype, 'num', ['int32'], 'linspace')
     check_dtype(dtype, 'dtype', ['int32', 'int64', 'float32', 'float64'],
                 'linspace')
+    if ((stop_dtype == "float64" or start_dtype == "float64") and
+            out_dtype in ["float32", "int32"]) or ((stop_dtype == "int64" or
+                                                    start_dtype == "int64") and
+                                                   out_dtype == "int32"):
+        raise ValueError(
+            "The dtype of start/stop is {}/{} but the attr(dtype) of linspace is {}, "
+            "which may cause data type overflows. Please reset attr(dtype) of linspace."
+            .format(start_dtype, stop_dtype, dtype))
 
     out = helper.create_variable_for_type_inference(dtype=dtype)
...
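Note: the new guard only rejects dtype combinations that would narrow the inputs; widening remains allowed. A standalone sketch of the equivalent predicate (_would_narrow is an illustrative helper, not part of the patch):

def _would_narrow(start_dtype, stop_dtype, out_dtype):
    # float64 endpoints must not be cast down to float32/int32,
    # and int64 endpoints must not be cast down to int32.
    if "float64" in (start_dtype, stop_dtype) and out_dtype in ("float32", "int32"):
        return True
    if "int64" in (start_dtype, stop_dtype) and out_dtype == "int32":
        return True
    return False

assert _would_narrow("float64", "float32", "float32")  # rejected with ValueError
assert not _would_narrow("int32", "int32", "float64")  # widening is fine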
@@ -154,16 +154,16 @@ class TestLinspaceOpError(unittest.TestCase):
         self.assertRaises(TypeError, test_step_dtype)
 
         def test_start_dtype():
-            start = fluid.data(shape=[1], dtype="int32", name="start")
+            start = fluid.data(shape=[1], dtype="float64", name="start")
             fluid.layers.linspace(start, 10, 1, dtype="float32")
 
-        self.assertRaises(TypeError, test_start_dtype)
+        self.assertRaises(ValueError, test_start_dtype)
 
         def test_end_dtype():
-            end = fluid.data(shape=[1], dtype="int32", name="end")
+            end = fluid.data(shape=[1], dtype="float64", name="end")
             fluid.layers.linspace(0, end, 1, dtype="float32")
 
-        self.assertRaises(TypeError, test_end_dtype)
+        self.assertRaises(ValueError, test_end_dtype)
 
         def test_num_dtype():
             num = fluid.data(shape=[1], dtype="int32", name="step")
...