Unverified commit a3c50c42, authored by BrilliantYuKaimin, committed by GitHub

[PaddlePaddle Hackathon 2] Task 9: Add the logspace API to Paddle (#41261)

* Add the operator description for logspace

* Add shape inference for logspace

* Add the logspace kernel implementations

* Add the logspace interface in Python

* Add unit tests for logspace

* Add logspace

* Update logspace_kernel.cu

* Update logspace_op.cc

* Adjust code format

* Update doc of logspace

* Update tensor.py

* Update logspace_op.cc

* Update logspace_kernel.cc

* Update logspace_kernel.cu

* Update test_logspace.py

* Adjust the position of logspace

* Adjust code format
Parent 885171e3
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <string>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/multiary.h"
namespace paddle {
namespace operators {
class LogspaceOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
framework::proto::VarType::Type(ctx.Attr<int>("dtype")),
ctx.GetPlace());
}
};
class LogspaceOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("Start",
"Exponent of first entry in the sequence. It is a tensor of "
"shape [1], should be of type int32, int64, float32 or float64.");
AddInput("Stop",
"Exponent of last entry in the sequence. It is a tensor of "
"shape [1], should be of type int32, int64, float32 or float64.");
AddInput("Num",
"Number of entry in the sequence. It is a tensor of shape [1], "
"should be of type int32.");
AddInput("Base",
"Base of the logarithm function. It is a tensor of shape [1], "
"should be of type int32, int64, float32 or float64.");
AddAttr<int>("dtype", "The output data type.");
AddOutput("Out", "A sequence of numbers.");
AddComment(R"DOC(
Return a fixed number of logarithmically evenly spaced values within a given
interval. The first entry is Base raised to the power Start, and the last
entry is Base raised to the power Stop. When Num is 1, only Base raised to
the power Start is returned. If dtype is int32 or int64, the decimal part
of the values is truncated.
Similar to the logspace function of numpy.
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(logspace, LogspaceInferShapeFunctor,
PD_INFER_META(phi::LogspaceInferMeta));
REGISTER_OPERATOR(
logspace, ops::LogspaceOp, ops::LogspaceOpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
LogspaceInferShapeFunctor);
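As a point of reference for the semantics spelled out in the operator comment above, the following NumPy sketch mirrors the documented behavior (illustrative only; logspace_ref is a hypothetical helper, not part of this patch):

import numpy as np

def logspace_ref(start, stop, num, base, dtype):
    # base ** linspace(start, stop, num); casting to an integer dtype
    # truncates the decimal part, as the operator comment describes.
    exponents = np.linspace(float(start), float(stop), num)
    return np.power(float(base), exponents).astype(dtype)

print(logspace_ref(0, 10, 5, 2, 'float32'))  # ~[1., 5.66, 32., 181.02, 1024.]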
@@ -1489,6 +1489,43 @@ void InterpolateInferMeta(
}
}
void LogspaceInferMeta(const MetaTensor& start,
const MetaTensor& stop,
const MetaTensor& number,
const MetaTensor& base,
MetaTensor* out) {
auto s_dims = start.dims();
PADDLE_ENFORCE_EQ(
(s_dims.size() == 1) && (s_dims[0] == 1),
true,
phi::errors::InvalidArgument("The shape of Input(Start) must be [1],"
"but received input shape is [%s].",
s_dims));
auto e_dims = stop.dims();
PADDLE_ENFORCE_EQ(
(e_dims.size() == 1) && (e_dims[0] == 1),
true,
phi::errors::InvalidArgument("The shape of Input(Stop) must be [1],"
"but received input shape is [%s].",
e_dims));
auto num_dims = number.dims();
PADDLE_ENFORCE_EQ(
(num_dims.size() == 1) && (num_dims[0] == 1),
true,
phi::errors::InvalidArgument("The shape of Input(Num) must be [1],"
"but received input shape is [%s].",
num_dims));
auto b_dims = base.dims();
PADDLE_ENFORCE_EQ(
(b_dims.size() == 1) && (b_dims[0] == 1),
true,
phi::errors::InvalidArgument("The shape of Input(Base) must be [1],"
"but received input shape is [%s].",
b_dims));
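// Note: the element count comes from the runtime value of Num, so it is not
// known when shapes are inferred; the output is marked dynamic ({-1}) here,
// and the kernel resizes it to {num} once Num has been read.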
out->set_dims(phi::make_ddim({-1}));
out->set_dtype(start.dtype());
}
void MeshgridInferMeta(const std::vector<const MetaTensor*>& inputs,
std::vector<MetaTensor*> outputs) {
const size_t inputs_num = inputs.size();
@@ -228,6 +228,12 @@ void InterpolateInferMeta(
MetaTensor* output,
MetaConfig config = MetaConfig());
void LogspaceInferMeta(const MetaTensor& start,
const MetaTensor& stop,
const MetaTensor& number,
const MetaTensor& base,
MetaTensor* out);
void MeshgridInferMeta(const std::vector<const MetaTensor*>& inputs,
std::vector<MetaTensor*> outputs);
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/logspace_kernel.h"
#include <cmath>
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/data_type_transform.h"
namespace phi {
template <typename T, typename Context>
void LogspaceKernel(const Context& ctx,
const DenseTensor& start,
const DenseTensor& stop,
const DenseTensor& number,
const DenseTensor& base,
DataType dtype,
DenseTensor* out) {
int32_t num = number.data<int32_t>()[0];
auto start_t = phi::funcs::TransDataType(ctx, start, dtype);
auto stop_t = phi::funcs::TransDataType(ctx, stop, dtype);
auto base_t = phi::funcs::TransDataType(ctx, base, dtype);
T start_data = start_t.template data<T>()[0];
T stop_data = stop_t.template data<T>()[0];
T base_data = base_t.template data<T>()[0];
PADDLE_ENFORCE_GT(
num,
0,
phi::errors::InvalidArgument("The num of logspace op should be larger "
"than 0, but received num is %d",
num));
out->Resize(phi::make_ddim({num}));
T* out_data = ctx.template Alloc<T>(out);
if (num > 1) {
// step should be of double type for all types
double step = (static_cast<double>(stop_data - start_data)) / (num - 1);
int half_num = num / 2;
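// Fill the first half forward from start and the second half backward from
// stop, so each endpoint is computed directly from its own anchor instead
// of accumulating rounding error across the whole sequence.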
for (int i = 0; i < num; ++i) {
if (i < half_num) {
out_data[i] =
static_cast<T>(std::pow(base_data, start_data + step * i));
} else {
out_data[i] = static_cast<T>(
std::pow(base_data, stop_data - step * (num - i - 1)));
}
}
} else {
out_data[0] = static_cast<T>(std::pow(base_data, start_data));
}
}
} // namespace phi
PD_REGISTER_KERNEL(logspace,
CPU,
ALL_LAYOUT,
phi::LogspaceKernel,
float,
int32_t,
int64_t,
double) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/logspace_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/copy_kernel.h"
#include "paddle/phi/kernels/funcs/data_type_transform.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace phi {
template <typename T>
__global__ void LogspaceKernelInner(
T start, T stop, double step, T base, int64_t size, T* out) {
int64_t index = blockIdx.x * blockDim.x + threadIdx.x;
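// Grid-stride loop; as in the CPU kernel, the first half is anchored on
// start and the second half on stop so that both endpoints come out exact.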
for (; index < size; index += blockDim.x * gridDim.x) {
if (index < size / 2) {
out[index] =
static_cast<T>(pow(static_cast<double>(base),
static_cast<double>(start + step * index)));
} else {
out[index] = static_cast<T>(
pow(static_cast<double>(base),
static_cast<double>(stop - step * (size - index - 1))));
}
}
}
template <typename T>
__global__ void LogspaceSpecialKernel(T start, T base, T* out) {
out[0] = static_cast<T>(
pow(static_cast<double>(base), static_cast<double>(start)));
}
template <typename T, typename Context>
void LogspaceKernel(const Context& ctx,
const DenseTensor& start,
const DenseTensor& stop,
const DenseTensor& number,
const DenseTensor& base,
DataType dtype,
DenseTensor* out) {
auto start_t = phi::funcs::TransDataType(ctx, start, dtype);
auto stop_t = phi::funcs::TransDataType(ctx, stop, dtype);
auto base_t = phi::funcs::TransDataType(ctx, base, dtype);
DenseTensor n_start;
DenseTensor n_stop;
DenseTensor n_num;
DenseTensor n_base;
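// The scalar inputs live on the device; copy them to host memory so their
// values can be read on the CPU to validate num and configure the launch.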
phi::Copy(ctx, start_t, phi::CPUPlace(), false, &n_start);
T start_data = n_start.data<T>()[0];
phi::Copy(ctx, stop_t, phi::CPUPlace(), false, &n_stop);
T stop_data = n_stop.data<T>()[0];
phi::Copy(ctx, number, phi::CPUPlace(), false, &n_num);
int64_t num = static_cast<int64_t>(n_num.data<int32_t>()[0]);
phi::Copy(ctx, base_t, phi::CPUPlace(), false, &n_base);
T base_data = n_base.data<T>()[0];
PADDLE_ENFORCE_GT(
num,
0,
phi::errors::InvalidArgument("The num of logspace op should be larger "
"than 0, but received num is %d",
num));
out->Resize(phi::make_ddim({num}));
T* out_data = ctx.template Alloc<T>(out);
double step = 0;
auto stream = ctx.stream();
int block = 512;
int grid = (num + block - 1) / block;
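// Launch one thread per output element, rounded up to whole blocks of 512.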
if (num != 1) {
step = (static_cast<double>(stop_data - start_data)) / (num - 1);
LogspaceKernelInner<T><<<grid, block, 0, stream>>>(
start_data, stop_data, step, base_data, num, out_data);
} else {
LogspaceSpecialKernel<T><<<grid, block, 0, stream>>>(
start_data, base_data, out_data);
}
}
} // namespace phi
PD_REGISTER_KERNEL(logspace,
GPU,
ALL_LAYOUT,
phi::LogspaceKernel,
float,
int32_t,
int64_t,
double) {}
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void LogspaceKernel(const Context& ctx,
const DenseTensor& start,
const DenseTensor& stop,
const DenseTensor& number,
const DenseTensor& base,
DataType dtype,
DenseTensor* out);
} // namespace phi
@@ -89,6 +89,7 @@ from .tensor.creation import diag # noqa: F401
from .tensor.creation import diagflat # noqa: F401
from .tensor.creation import eye # noqa: F401
from .tensor.creation import linspace # noqa: F401
from .tensor.creation import logspace # noqa: F401
from .tensor.creation import ones # noqa: F401
from .tensor.creation import ones_like # noqa: F401
from .tensor.creation import zeros # noqa: F401
@@ -591,6 +592,7 @@ __all__ = [ # noqa
'sqrt',
'randperm',
'linspace',
'logspace',
'reshape',
'reshape_',
'reverse',
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
from op_test import OpTest
import paddle
class TestLogspaceOpCommonCase(OpTest):
def setUp(self):
self.op_type = "logspace"
dtype = 'float32'
self.inputs = {
'Start': np.array([0]).astype(dtype),
'Stop': np.array([10]).astype(dtype),
'Num': np.array([11]).astype('int32'),
'Base': np.array([2]).astype(dtype),
}
self.attrs = {'dtype': int(paddle.float32)}
self.outputs = {'Out': np.power(2, np.arange(0, 11)).astype(dtype)}
def test_check_output(self):
self.check_output()
class TestLogspaceOpReverseCase(OpTest):
def setUp(self):
self.op_type = "logspace"
dtype = 'float32'
self.inputs = {
'Start': np.array([10]).astype(dtype),
'Stop': np.array([0]).astype(dtype),
'Num': np.array([11]).astype('int32'),
'Base': np.array([2]).astype(dtype)
}
self.attrs = {'dtype': int(paddle.float32)}
self.outputs = {'Out': np.power(2, np.arange(10, -1, -1)).astype(dtype)}
def test_check_output(self):
self.check_output()
class TestLogspaceOpNumOneCase(OpTest):
def setUp(self):
self.op_type = "logspace"
dtype = 'float32'
self.inputs = {
'Start': np.array([10]).astype(dtype),
'Stop': np.array([0]).astype(dtype),
'Num': np.array([1]).astype('int32'),
'Base': np.array([2]).astype(dtype)
}
self.attrs = {'dtype': int(paddle.float32)}
self.outputs = {'Out': np.power(2, np.array(10)).astype(dtype)}
def test_check_output(self):
self.check_output()
class TestLogspaceOpMinusBaseCase(OpTest):
def setUp(self):
self.op_type = "logspace"
dtype = 'float32'
self.inputs = {
'Start': np.array([0]).astype(dtype),
'Stop': np.array([10]).astype(dtype),
'Num': np.array([11]).astype('int32'),
'Base': np.array([-2]).astype(dtype),
}
self.attrs = {'dtype': int(paddle.float32)}
self.outputs = {'Out': np.power(-2, np.arange(0, 11)).astype(dtype)}
def test_check_output(self):
self.check_output()
class TestLogspaceOpZeroBaseCase(OpTest):
def setUp(self):
self.op_type = "logspace"
dtype = 'float32'
self.inputs = {
'Start': np.array([0]).astype(dtype),
'Stop': np.array([10]).astype(dtype),
'Num': np.array([11]).astype('int32'),
'Base': np.array([0]).astype(dtype),
}
self.attrs = {'dtype': int(paddle.float32)}
self.outputs = {'Out': np.power(0, np.arange(0, 11)).astype(dtype)}
def test_check_output(self):
self.check_output()
class TestLogspaceAPI(unittest.TestCase):
def test_variable_input1(self):
paddle.enable_static()
prog = paddle.static.Program()
with paddle.static.program_guard(prog):
start = paddle.full(shape=[1], fill_value=0, dtype='float32')
stop = paddle.full(shape=[1], fill_value=10, dtype='float32')
num = paddle.full(shape=[1], fill_value=5, dtype='int32')
base = paddle.full(shape=[1], fill_value=2, dtype='float32')
out = paddle.logspace(start, stop, num, base, dtype='float32')
exe = paddle.static.Executor()
res = exe.run(prog, fetch_list=[out])
np_res = np.logspace(0, 10, 5, base=2, dtype='float32')
self.assertEqual((res == np_res).all(), True)
paddle.disable_static()
def test_variable_input2(self):
paddle.disable_static()
start = paddle.full(shape=[1], fill_value=0, dtype='float32')
stop = paddle.full(shape=[1], fill_value=10, dtype='float32')
num = paddle.full(shape=[1], fill_value=5, dtype='int32')
base = paddle.full(shape=[1], fill_value=2, dtype='float32')
out = paddle.logspace(start, stop, num, base, dtype='float32')
np_res = np.logspace(0, 10, 5, base=2, dtype='float32')
self.assertEqual((out.numpy() == np_res).all(), True)
paddle.enable_static()
def test_dtype(self):
paddle.enable_static()
prog = paddle.static.Program()
with paddle.static.program_guard(prog):
out_1 = paddle.logspace(0, 10, 5, 2, dtype='float32')
out_2 = paddle.logspace(0, 10, 5, 2, dtype=np.float32)
exe = paddle.static.Executor()
res_1, res_2 = exe.run(prog, fetch_list=[out_1, out_2])
assert np.array_equal(res_1, res_2)
paddle.disable_static()
def test_name(self):
with paddle.static.program_guard(paddle.static.Program()):
out = paddle.logspace(
0, 10, 5, 2, dtype='float32', name='logspace_res')
assert 'logspace_res' in out.name
def test_imperative(self):
paddle.disable_static()
out1 = paddle.logspace(0, 10, 5, 2, dtype='float32')
np_out1 = np.logspace(0, 10, 5, base=2, dtype='float32')
out2 = paddle.logspace(0, 10, 5, 2, dtype='int32')
np_out2 = np.logspace(0, 10, 5, base=2, dtype='int32')
out3 = paddle.logspace(0, 10, 200, 2, dtype='int32')
np_out3 = np.logspace(0, 10, 200, base=2, dtype='int32')
paddle.enable_static()
self.assertEqual((out1.numpy() == np_out1).all(), True)
self.assertEqual((out2.numpy() == np_out2).all(), True)
self.assertEqual((out3.numpy() == np_out3).all(), True)
class TestLogspaceOpError(unittest.TestCase):
def test_errors(self):
with paddle.static.program_guard(paddle.static.Program()):
def test_dtype():
paddle.logspace(0, 10, 1, 2, dtype="int8")
self.assertRaises(TypeError, test_dtype)
def test_dtype1():
paddle.logspace(0, 10, 1.33, 2, dtype="int32")
self.assertRaises(TypeError, test_dtype1)
def test_start_type():
paddle.logspace([0], 10, 1, 2, dtype="float32")
self.assertRaises(TypeError, test_start_type)
def test_end_type():
paddle.logspace(0, [10], 1, 2, dtype="float32")
self.assertRaises(TypeError, test_end_type)
def test_num_type():
paddle.logspace(0, 10, [0], 2, dtype="float32")
self.assertRaises(TypeError, test_num_type)
def test_start_dtype():
start = paddle.static.data(
shape=[1], dtype="float64", name="start")
paddle.logspace(start, 10, 1, 2, dtype="float32")
self.assertRaises(ValueError, test_start_dtype)
def test_end_dtype():
end = paddle.static.data(shape=[1], dtype="float64", name="end")
paddle.logspace(0, end, 1, 2, dtype="float32")
self.assertRaises(ValueError, test_end_dtype)
def test_num_dtype():
num = paddle.static.data(
shape=[1], dtype="float32", name="step")
paddle.logspace(0, 10, num, 2, dtype="float32")
self.assertRaises(TypeError, test_num_dtype)
def test_base_dtype():
base = paddle.static.data(
shape=[1], dtype="float64", name="end")
paddle.logspace(0, 10, 1, base, dtype="float32")
self.assertRaises(ValueError, test_base_dtype)
if __name__ == "__main__":
unittest.main()
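The dtype-mismatch cases above (test_start_dtype and the ones that follow) exercise the narrowing check inside the Python API shown in the next file. A standalone sketch of what they assert (illustrative; static graph mode is assumed, since the check only runs on that path):

import paddle

paddle.enable_static()
with paddle.static.program_guard(paddle.static.Program()):
    start = paddle.static.data(name='start', shape=[1], dtype='float64')
    try:
        # A float64 input with a float32 output dtype risks precision loss,
        # so the API rejects the combination.
        paddle.logspace(start, 10, 5, 2, dtype='float32')
    except ValueError as e:
        print(e)
paddle.disable_static()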
@@ -146,6 +146,130 @@ def linspace(start, stop, num, dtype=None, name=None):
return out
def logspace(start, stop, num, base=10.0, dtype=None, name=None):
r"""
Return a fixed number of logarithmically evenly spaced values within the interval \
:math:`[base^{start}, base^{stop}]`.
Notes:
This API does not compute the gradient.
Args:
start(int|float|Tensor): The input :attr:`start` is the exponent of the first entry in \
the sequence. It is a scalar, or a Tensor of shape [1] with input data \
type int32, int64, float32 or float64.
stop(int|float|Tensor): The input :attr:`stop` is the exponent of the last entry in the \
sequence. It is a scalar, or a Tensor of shape [1] with input data \
type int32, int64, float32 or float64.
num(int|Tensor): The input :attr:`num` is the given number of items in the sequence. \
It is an int scalar, or a Tensor of shape [1] with data type int32.
base(int|float|Tensor): The input :attr:`base` is the base of the logarithm function. \
It is a scalar, or a Tensor of shape [1] with input data type int32, int64, \
float32 or float64.
dtype(np.dtype|str, optional): The data type of the output tensor, which can be \
int32, int64, float32 or float64. Default: if None, the data type is float32. \
name(str, optional): Normally there is no need for the user to set this property. \
For more information, please refer to :ref:`api_guide_Name`. Default: None.
Returns:
Tensor: A 1-D tensor of shape :math:`[num]`, holding the fixed number of \
logarithmically evenly spaced values, with data type :attr:`dtype`. If \
:attr:`num` is set to 1, the output tensor holds only the single value \
:math:`base^{start}`.
Examples:
.. code-block:: python
:name: logspace-example
import paddle
data = paddle.logspace(0, 10, 5, 2, 'float32')
# [1. , 5.65685415 , 32. , 181.01933289, 1024. ]
data = paddle.logspace(0, 10, 1, 2, 'float32')
# [1.]
"""
if dtype is None:
dtype = 'float32'
tensor_num = num
tensor_start = start
tensor_stop = stop
tensor_base = base
if not isinstance(num, Variable):
check_type(num, 'num', (int), 'logspace')
if not isinstance(dtype, core.VarDesc.VarType):
dtype = convert_np_dtype_to_dtype_(dtype)
if not isinstance(start, Variable):
with device_guard("cpu"):
tensor_start = fill_constant([1], dtype, start)
if not isinstance(stop, Variable):
with device_guard("cpu"):
tensor_stop = fill_constant([1], dtype, stop)
if not isinstance(num, Variable):
with device_guard("cpu"):
tensor_num = fill_constant([1], 'int32', num)
if not isinstance(base, Variable):
with device_guard("cpu"):
tensor_base = fill_constant([1], dtype, base)
if _non_static_mode():
return _C_ops.logspace(tensor_start, tensor_stop, tensor_num,
tensor_base, 'dtype', dtype)
helper = LayerHelper("logspace", **locals())
start_dtype = convert_dtype(tensor_start.dtype)
stop_dtype = convert_dtype(tensor_stop.dtype)
base_dtype = convert_dtype(tensor_base.dtype)
out_dtype = convert_dtype(dtype)
if isinstance(start, Variable):
check_dtype(start.dtype, 'start',
['float32', 'float64', 'int32', 'int64'], 'logspace')
else:
check_type(start, 'start', (int, float), 'logspace')
if isinstance(stop, Variable):
check_dtype(stop.dtype, 'stop',
['float32', 'float64', 'int32', 'int64'], 'logspace')
else:
check_type(stop, 'stop', (int, float), 'logspace')
if isinstance(num, Variable):
check_dtype(num.dtype, 'num', ['int32'], 'logspace')
if isinstance(base, Variable):
check_dtype(base.dtype, 'base',
['float32', 'float64', 'int32', 'int64'], 'logspace')
else:
check_type(base, 'base', (int, float), 'logspace')
check_dtype(dtype, 'dtype', ['int32', 'int64', 'float32', 'float64'],
'logspace')
if ((stop_dtype == "float64" or start_dtype == "float64"
or base_dtype == "float64")
and out_dtype in ["float32", "int32"]) or \
((stop_dtype == "int64" or start_dtype == "int64"
or base_dtype == "int64")
and out_dtype == "int32"):
raise ValueError(
"The dtype of start/stop/base is {}/{}/{} but the attr(dtype) of logspace is {}, "
"which may cause data type overflows. Please reset attr(dtype) of logspace."
.format(start_dtype, stop_dtype, base_dtype, dtype))
out = helper.create_variable_for_type_inference(dtype=dtype)
helper.append_op(
type='logspace',
inputs={
'Start': tensor_start,
'Stop': tensor_stop,
'Num': tensor_num,
'Base': tensor_base
},
attrs={'dtype': dtype},
outputs={'Out': [out]})
if isinstance(num, int):
out.desc.set_shape((num, ))
return out
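One consequence of the shape fix-up just above: when num is a Python int, the output's static shape is known up front, while a Tensor num leaves the shape dynamic (the [-1] set by LogspaceInferMeta) until runtime. A small sketch under static graph mode (illustrative):

import paddle

paddle.enable_static()
with paddle.static.program_guard(paddle.static.Program()):
    out_known = paddle.logspace(0, 10, 5, 2, dtype='float32')
    print(out_known.shape)  # num is a Python int, so the shape is 5
    num_t = paddle.full(shape=[1], fill_value=5, dtype='int32')
    out_dyn = paddle.logspace(0, 10, num_t, 2, dtype='float32')
    print(out_dyn.shape)  # num is a Tensor, so the shape stays -1 until runtime
paddle.disable_static()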
@dygraph_only
def to_tensor(data, dtype=None, place=None, stop_gradient=True):
r"""
@@ -306,6 +306,7 @@ STATIC_MODE_TESTING_LIST = [
'test_linear_interp_op',
'test_linear_interp_v2_op',
'test_linspace',
'test_logspace',
'test_load_op',
'test_load_vars_shape_check',
'test_locality_aware_nms_op',