未验证 提交 e2897ba1 编写于 作者: H Hongyu Liu 提交者: GitHub

Merge pull request #16432 from zhoukunsheng/linspace

add linspace op
......@@ -275,6 +275,7 @@ paddle.fluid.layers.has_inf (ArgSpec(args=['x'], varargs=None, keywords=None, de
paddle.fluid.layers.has_nan (ArgSpec(args=['x'], varargs=None, keywords=None, defaults=None), ('document', '2e53e83127dbfd86e7098bdfe9a549e8'))
paddle.fluid.layers.isfinite (ArgSpec(args=['x'], varargs=None, keywords=None, defaults=None), ('document', '0a437011c3906079fd8947ed3e52d292'))
paddle.fluid.layers.range (ArgSpec(args=['start', 'end', 'step', 'dtype'], varargs=None, keywords=None, defaults=None), ('document', '2ec937ede953ded2fdff2675883900bb'))
paddle.fluid.layers.linspace (ArgSpec(args=['start', 'stop', 'num', 'dtype'], varargs=None, keywords=None, defaults=None), ('document', '495e21e9a848c2d075a102802fc67756'))
paddle.fluid.layers.While.__init__ (ArgSpec(args=['self', 'cond', 'is_test', 'name'], varargs=None, keywords=None, defaults=(False, None)), ('document', '6adf97f83acf6453d4a6a4b1070f3754'))
paddle.fluid.layers.While.block (ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None), ('document', '6adf97f83acf6453d4a6a4b1070f3754'))
paddle.fluid.layers.Switch.__init__ (ArgSpec(args=['self', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '6adf97f83acf6453d4a6a4b1070f3754'))
......
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/linspace_op.h"
namespace paddle {
namespace operators {
class LinspaceOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Start"),
"Input(Start) of LinspaceOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Stop"),
"Input(Stop) of LinspaceOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Num"),
"Input(Num) of LinspaceOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(OUt) of LinspaceOp should not be null.");
auto s_dims = ctx->GetInputDim("Start");
PADDLE_ENFORCE((s_dims.size() == 1) && (s_dims[0] == 1),
"The shape of Input(Start) should be [1].");
auto e_dims = ctx->GetInputDim("Stop");
PADDLE_ENFORCE((e_dims.size() == 1) && (e_dims[0] == 1),
"The shape of Input(Stop) should be [1].");
auto step_dims = ctx->GetInputDim("Num");
PADDLE_ENFORCE((step_dims.size() == 1) && (step_dims[0] == 1),
"The shape of Input(Num) should be [1].");
ctx->SetOutputDim("Out", {-1});
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
framework::LibraryType library_{framework::LibraryType::kPlain};
framework::DataLayout layout_ = framework::DataLayout::kAnyLayout;
return framework::OpKernelType(
ctx.Input<framework::Tensor>("Start")->type(), ctx.device_context(),
layout_, library_);
}
};
class LinspaceOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("Start",
"First entry in the sequence. It is a tensor of shape [1], should "
"be of type float32 or float64.");
AddInput("Stop",
"Last entry in the sequence. It is a tensor of shape [1], should "
"be of type float32 or float64.");
AddInput("Num",
"Number of entry in the sequence. It is a tensor of shape [1], "
"should be of type int32.");
AddOutput("Out", "A sequence of numbers.");
AddComment(R"DOC(
Return fixed number of evenly spaced values within a given interval. First entry is start, and last entry is stop. In the case when Num is 1, only Start is returned. Like linspace function of numpy.
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(linspace, ops::LinspaceOp, ops::LinspaceOpMaker);
REGISTER_OP_CPU_KERNEL(linspace, ops::CPULinspaceKernel<float>,
ops::CPULinspaceKernel<double>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/linspace_op.h"
#include "paddle/fluid/platform/cuda_primitives.h"
namespace paddle {
namespace operators {
#define CUDA_1D_KERNEL_LOOP(i, n) \
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
i += blockDim.x * gridDim.x)
template <typename T>
__global__ void LinspaceKernel(T start, T step, int64_t size, T* out) {
CUDA_1D_KERNEL_LOOP(index, size) { out[index] = start + step * index; }
}
template <typename T>
__global__ void LinspaceSpecialKernel(T start, T* out) {
out[0] = start;
}
template <typename T>
class CUDALinspaceKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* start_t = context.Input<framework::Tensor>("Start");
auto* stop_t = context.Input<framework::Tensor>("Stop");
auto* num_t = context.Input<framework::Tensor>("Num");
auto* out = context.Output<framework::Tensor>("Out");
framework::Tensor n;
framework::TensorCopy(*start_t, platform::CPUPlace(), &n);
T start = n.data<T>()[0];
framework::TensorCopy(*stop_t, platform::CPUPlace(), &n);
T stop = n.data<T>()[0];
framework::TensorCopy(*num_t, platform::CPUPlace(), &n);
int32_t num = n.data<int32_t>()[0];
PADDLE_ENFORCE(num > 0, "The num of linspace op should be larger than 0.");
out->Resize(framework::make_ddim({num}));
T* out_data = out->mutable_data<T>(context.GetPlace());
T step = 0;
if (num != 1) {
step = (stop - start) / (num - 1);
}
auto stream = context.cuda_device_context().stream();
int block = 512;
int grid = (num + block - 1) / block;
LinspaceKernel<T><<<grid, block, 0, stream>>>(start, step, num, out_data);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(linspace, ops::CUDALinspaceKernel<float>,
ops::CUDALinspaceKernel<double>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <functional>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/math_function.h"
namespace paddle {
namespace operators {
template <typename T>
class CPULinspaceKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
T start = context.Input<framework::Tensor>("Start")->data<T>()[0];
T stop = context.Input<framework::Tensor>("Stop")->data<T>()[0];
int32_t num = context.Input<framework::Tensor>("Num")->data<int32_t>()[0];
auto* out = context.Output<framework::Tensor>("Out");
PADDLE_ENFORCE(num > 0, "The num of linspace op should be larger than 0.");
out->Resize(framework::make_ddim({num}));
T* out_data = out->mutable_data<T>(context.GetPlace());
if (num > 1) {
T step = (stop - start) / (num - 1);
T value = start;
for (int i = 0; i < num; ++i) {
out_data[i] = value;
value += step;
}
} else {
out_data[0] = start;
}
}
};
} // namespace operators
} // namespace paddle
......@@ -24,26 +24,11 @@ from .layer_function_generator import templatedoc
import numpy
__all__ = [
'create_tensor',
'create_parameter',
'create_global_var',
'cast',
'tensor_array_to_tensor',
'concat',
'sums',
'assign',
'fill_constant_batch_size_like',
'fill_constant',
'argmin',
'argmax',
'argsort',
'ones',
'zeros',
'reverse',
'has_inf',
'has_nan',
'isfinite',
'range',
'create_tensor', 'create_parameter', 'create_global_var', 'cast',
'tensor_array_to_tensor', 'concat', 'sums', 'assign',
'fill_constant_batch_size_like', 'fill_constant', 'argmin', 'argmax',
'argsort', 'ones', 'zeros', 'reverse', 'has_inf', 'has_nan', 'isfinite',
'range', 'linspace'
]
......@@ -826,3 +811,45 @@ def range(start, end, step, dtype):
'Step': step},
outputs={'Out': [out]})
return out
def linspace(start, stop, num, dtype):
"""
Return fixed number of evenly spaced values within a given interval.
First entry is start, and last entry is stop. In the case when Num is 1, only Start is returned. Like linspace function of numpy.
Args:
start(float|Variable): First entry in the sequence. It is a float scalar, or a tensor of shape [1] with type 'float32'|'float64'.
stop(float|Variable): Last entry in the sequence. It is a float scalar, or a tensor of shape [1] with type 'float32'|'float64'.
num(int|Variable): Number of entry in the sequence. It is an int scalar, or a tensor of shape [1] with type int32.
dtype(string): 'float32'|'float64', the data type of the output tensor.
Returns:
Variable: The tensor variable storing a 1-D tensor.
Examples:
.. code-block:: python
data = fluid.layers.linspace(0, 10, 5, 'float32') # [0.0, 2.5, 5.0, 7.5, 10.0]
data = fluid.layers.linspace(0, 10, 1, 'float32') # [0.0]
"""
helper = LayerHelper("linspace", **locals())
if not isinstance(start, Variable):
start = fill_constant([1], dtype, start)
if not isinstance(stop, Variable):
stop = fill_constant([1], dtype, stop)
if not isinstance(num, Variable):
num = fill_constant([1], 'int32', num)
out = helper.create_variable_for_type_inference(dtype=start.dtype)
helper.append_op(
type='linspace',
inputs={'Start': start,
'Stop': stop,
'Num': num},
outputs={'Out': [out]})
return out
......@@ -1925,6 +1925,13 @@ class TestBook(LayerTest):
out = layers.flatten(x, axis=1, name="flatten")
return (out)
def test_linspace(self):
program = Program()
with program_guard(program):
out = layers.linspace(20, 10, 5, 'float64')
self.assertIsNotNone(out)
print(str(program))
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
from op_test import OpTest
class TestLinspaceOpCommonCase(OpTest):
def setUp(self):
self.op_type = "linspace"
dtype = 'float32'
self.inputs = {
'Start': np.array([0]).astype(dtype),
'Stop': np.array([10]).astype(dtype),
'Num': np.array([11]).astype('int32')
}
self.outputs = {'Out': np.arange(0, 11).astype(dtype)}
def test_check_output(self):
self.check_output()
class TestLinspaceOpReverseCase(OpTest):
def setUp(self):
self.op_type = "linspace"
dtype = 'float32'
self.inputs = {
'Start': np.array([10]).astype(dtype),
'Stop': np.array([0]).astype(dtype),
'Num': np.array([11]).astype('int32')
}
self.outputs = {'Out': np.arange(10, -1, -1).astype(dtype)}
def test_check_output(self):
self.check_output()
class TestLinspaceOpNumOneCase(OpTest):
def setUp(self):
self.op_type = "linspace"
dtype = 'float32'
self.inputs = {
'Start': np.array([10]).astype(dtype),
'Stop': np.array([0]).astype(dtype),
'Num': np.array([1]).astype('int32')
}
self.outputs = {'Out': np.array(10, dtype=dtype)}
def test_check_output(self):
self.check_output()
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册