未验证 提交 3e956865 编写于 作者: L liuyuhui 提交者: GitHub

add cast/concat/assign xpu op (#27911)

* addd

* add cast_op_xpu, test=kunlun

* fix bug for cast_op_xpu,test=kunlun

* add concat_op_xpu, test=kunlun

* slove conflicts, test=kunlun

* fix bug,test=kunlun

* add assign_op_xpu, test=kunlun

* fix bug,test=kunlun

* test=kunlun;test=develop

* fix concat bug,test=kunlun

* fix check_dygraph set in test_concat_op_xpu.py,test=kunlun

* fix error message,test=kunlun
Co-authored-by: Nmapingshuo <mps2012@yeah.net>
上级 2ed84a67
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef PADDLE_WITH_XPU
#include "paddle/fluid/operators/assign_op.h"
#include <string>
namespace paddle {
namespace framework {
class OpDesc;
class Variable;
} // namespace framework
namespace imperative {
class OpBase;
} // namespace imperative
namespace platform {
struct CPUPlace;
struct CUDAPlace;
struct float16;
} // namespace platform
} // namespace paddle
namespace paddle {
namespace operators {
class AssignOp : public framework::OperatorWithKernel {
public:
AssignOp(const std::string &type, const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: OperatorWithKernel(type, inputs, outputs, attrs) {}
void InferShape(framework::InferShapeContext *ctx) const override {
if (ctx->HasInput("X")) {
auto type = ctx->GetInputsVarType("X")[0];
if (type == framework::proto::VarType::SELECTED_ROWS ||
type == framework::proto::VarType::LOD_TENSOR) {
ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
if (type == framework::proto::VarType::LOD_TENSOR) {
ctx->ShareLoD("X", /*->*/ "Out");
}
} else if (type == framework::proto::VarType::LOD_TENSOR_ARRAY) {
if (ctx->IsRuntime()) {
// The runtime output shape is determined in kernel.
return;
} else {
ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
}
}
}
}
protected:
framework::OpKernelType GetKernelTypeForVar(
const std::string &var_name, const framework::Tensor &tensor,
const framework::OpKernelType &expected_kernel_type) const override {
return framework::OpKernelType(expected_kernel_type.data_type_,
expected_kernel_type.place_,
tensor.layout());
}
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
const framework::Variable *var = ctx.InputVar("X");
if (var->IsType<framework::LoDTensorArray>()) {
auto t_arr = var->Get<framework::LoDTensorArray>();
// NOTE(liym27): Support an empty tensor array as Input.
// And set the kernel type is float.
if (t_arr.size() == 0) {
return framework::OpKernelType(framework::proto::VarType::FP32,
ctx.device_context());
}
}
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"),
ctx.device_context());
}
};
class AssignInferVarType : public framework::VarTypeInference {
public:
void operator()(framework::InferVarTypeContext *ctx) const override {
ctx->SyncTypeAndDataType("X", "Out");
}
};
class AssignKernel {
public:
void operator()(const framework::ExecutionContext &ctx) const {
auto *x = ctx.InputVar("X");
if (x == nullptr) {
return;
}
PADDLE_ENFORCE_EQ(
ctx.HasOutput("Out"), true,
platform::errors::NotFound("Output(Out) of assign_op is not found."));
auto *out = ctx.OutputVar("Out");
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(ctx.GetPlace());
framework::VisitVarType(*x, AssignFunctor(out, dev_ctx));
}
};
class AssignOpProtoMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X",
"(LoDTensor, SelectedRows or LoDTensorArray) The input variable "
"could be LoDTensor, SelectedRows or LoDTensorArray.")
.AsDispensable();
AddOutput("Out",
"(LoDTensor, SelectedRows or LoDTensorArray) The type of output "
"is the same as input X.");
AddComment(R"DOC(Assign Operator
Out = X, when type in [LoDTensor/SelectedRows/LoDTensorArray]
raise error if the type is not listed above.
)DOC");
}
};
template <typename T>
class AssignGradMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
void Apply(GradOpPtr<T> op) const override {
op->SetType("assign");
op->SetInput("X", this->OutputGrad("Out"));
op->SetOutput("Out", this->InputGrad("X"));
}
};
DECLARE_INPLACE_OP_INFERER(AssignOpInplaceInferer, {"X", "Out"});
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_XPU_KERNEL_FUNCTOR(assign, float, ops::AssignKernel, double,
ops::AssignKernel, int, ops::AssignKernel,
int64_t, ops::AssignKernel, bool,
ops::AssignKernel);
#endif
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef PADDLE_WITH_XPU
#include "paddle/fluid/operators/cast_op.h"
#include <memory>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/float16.h"
namespace paddle {
namespace operators {
template <typename DeviceContext, typename InT>
class CastXPUKernel : public framework::OpKernel<InT> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* in = context.Input<framework::Tensor>("X");
auto* out = context.Output<framework::Tensor>("Out");
auto in_type = static_cast<framework::proto::VarType::Type>(
context.Attr<int>("in_dtype"));
auto out_type = static_cast<framework::proto::VarType::Type>(
context.Attr<int>("out_dtype"));
auto* in_data = in->data<InT>();
auto numel = in->numel();
auto& dev_ctx = context.template device_context<DeviceContext>();
int r = -1;
if (out_type == framework::proto::VarType::FP32) {
auto* out_data = out->mutable_data<float>(context.GetPlace());
r = xpu::cast<InT, float>(dev_ctx.x_context(), in_data, out_data, numel);
} else if (out_type == framework::proto::VarType::INT32) {
auto* out_data = out->mutable_data<int>(context.GetPlace());
r = xpu::cast<InT, int>(dev_ctx.x_context(), in_data, out_data, numel);
} else if (out_type == framework::proto::VarType::INT64) {
auto* out_data = out->mutable_data<int64_t>(context.GetPlace());
r = xpu::cast<InT, int64_t>(dev_ctx.x_context(), in_data, out_data,
numel);
} else {
PADDLE_THROW(platform::errors::Unavailable("Not supported cast %d -> %d",
in_type, out_type));
}
PADDLE_ENFORCE_EQ(
r, XPU_SUCCESS,
platform::errors::External(
"XPU API return wrong value[%d], please check whether "
"Baidu Kunlun Card is properly installed.",
r));
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_XPU_KERNEL(
cast, ops::CastXPUKernel<paddle::platform::XPUDeviceContext, int>,
ops::CastXPUKernel<paddle::platform::XPUDeviceContext, float>,
ops::CastXPUKernel<paddle::platform::XPUDeviceContext, int64_t>);
#endif
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/concat_op.h"
#include <memory>
#include <string>
#include <vector>
#ifdef PADDLE_WITH_MKLDNN
#include <paddle/fluid/platform/mkldnn_helper.h>
#endif
#ifdef PADDLE_WITH_XPU
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename DeviceContext, typename T>
class ConcatXPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto ins = ctx.MultiInput<framework::Tensor>("X");
framework::Tensor* out = ctx.Output<framework::Tensor>("Out");
int axis = ctx.Attr<int>("axis");
PADDLE_ENFORCE_NE(ins[0], nullptr, platform::errors::InvalidArgument(
"The input should not be null."));
PADDLE_ENFORCE_NE(ctx.HasInput("AxisTensor"), true,
platform::errors::InvalidArgument(
"XPU donot surpport AxisTensor for now"));
axis = ComputeAxis(static_cast<int64_t>(axis),
static_cast<int64_t>(ins[0]->dims().size()));
PADDLE_ENFORCE_GE(
axis, 0, platform::errors::InvalidArgument("concat: axis shoud >= 0!"));
PADDLE_ENFORCE_LT(axis, ins[0]->dims().size(),
platform::errors::InvalidArgument(
"concat: axis shoud < ins[0]->dims()!"));
auto place = ctx.GetPlace();
out->mutable_data<T>(place);
std::vector<int> choose_idx;
int n = 0;
for (unsigned int i = 0; i < ins.size(); ++i) {
if (ins[i] && ins[i]->numel() > 0) {
choose_idx.push_back(i);
n++;
}
}
PADDLE_ENFORCE_LE(n, 8, platform::errors::InvalidArgument(
"XPU only surpport at most 8 tensors for now"));
PADDLE_ENFORCE_GT(
n, 0, platform::errors::InvalidArgument("No tensor need concat?"));
int h = 1;
int w_except_axis = 1;
for (int i = 0; i < axis; ++i) {
h *= (ins[choose_idx[0]]->dims())[i];
}
for (int i = axis + 1; i < ins[0]->dims().size(); ++i) {
w_except_axis *= (ins[choose_idx[0]]->dims())[i];
}
for (int i = 1; i < n; ++i) {
int hh = 1;
int ww = 1;
for (int j = 0; j < axis; ++j) {
hh *= (ins[choose_idx[i]]->dims())[j];
}
for (int j = axis + 1; j < ins[i]->dims().size(); ++j) {
ww *= (ins[choose_idx[i]]->dims())[j];
}
PADDLE_ENFORCE_EQ(hh, h, platform::errors::InvalidArgument(
"concat: h should be eual!"));
PADDLE_ENFORCE_EQ(ww, w_except_axis,
platform::errors::InvalidArgument(
"concat: w should be eual except for axis!"));
}
auto& dev_ctx = ctx.template device_context<DeviceContext>();
std::unique_ptr<int[]> in_w_host(new int[n]);
std::unique_ptr<const float* []> ptrs(new const float*[n]);
for (int i = 0; i < n; ++i) {
ptrs[i] = ins[choose_idx[i]]->data<T>();
in_w_host[i] = w_except_axis * (ins[choose_idx[i]]->dims())[axis];
}
int r =
xpu::concat<float>(dev_ctx.x_context(), h, (const int*)in_w_host.get(),
n, (const float**)ptrs.get(), out->data<T>());
PADDLE_ENFORCE_EQ(
r, XPU_SUCCESS,
platform::errors::External(
"XPU API return wrong value[%d], please check whether "
"Baidu Kunlun Card is properly installed.",
r));
}
};
template <typename DeviceContext, typename T>
class ConcatGradXPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const {
auto* out_grad =
ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
auto ins = ctx.MultiInput<framework::LoDTensor>("X");
auto out_var_names = ctx.OutputNames(framework::GradVarName("X"));
auto outs =
ctx.MultiOutput<framework::LoDTensor>(framework::GradVarName("X"));
{
auto dx = outs;
auto x = ins;
for (size_t i = 0; i < dx.size(); ++i) {
if (dx[i] != nullptr) {
dx[i]->set_lod(x[i]->lod());
}
}
}
PADDLE_ENFORCE_NE(ins[0], nullptr, platform::errors::InvalidArgument(
"The input should not be null."));
auto axis = ctx.Attr<int>("axis");
if (ctx.HasInput("AxisTensor")) {
auto* axis_tensor = ctx.Input<framework::Tensor>("AxisTensor");
axis = GetDataFromTensor<int>(axis_tensor)[0];
}
axis = ComputeAxis(static_cast<int64_t>(axis),
static_cast<int64_t>(ins[0]->dims().size()));
// get output tensor that the name is not kEmptyVarName
std::vector<framework::Tensor*> outputs;
for (size_t j = 0; j < outs.size(); ++j) {
if (out_var_names[j] != framework::kEmptyVarName &&
outs[j]->numel() != 0UL) {
outs[j]->mutable_data<T>(ctx.GetPlace());
outputs.push_back(outs[j]);
} else {
outputs.push_back(nullptr);
}
}
PADDLE_ENFORCE_GE(axis, 0, platform::errors::InvalidArgument(
"concat_grad: axis shoud >= 0!"));
PADDLE_ENFORCE_LT(axis, out_grad->dims().size(),
platform::errors::InvalidArgument(
"concat_grad: axis shoud < ins[0]->dims()!"));
auto out_grad_stride = framework::stride_numel(out_grad->dims());
int n = outputs.size();
PADDLE_ENFORCE_LE(n, 16,
platform::errors::InvalidArgument(
"XPU only surpport at most 16 tensors for now"));
int h = out_grad_stride[0] / out_grad_stride[axis];
auto& dev_ctx = ctx.template device_context<DeviceContext>();
std::unique_ptr<int[]> in_w_host(new int[n]);
std::unique_ptr<float* []> ptrs(new float*[n]);
for (int i = 0; i < n; ++i) {
auto out_stride = framework::stride_numel(outputs[i]->dims());
ptrs[i] = outputs[i]->data<T>();
in_w_host[i] = out_stride[axis];
}
int r = xpu::concat_grad(dev_ctx.x_context(), h, in_w_host.get(), n,
reinterpret_cast<float**>(ptrs.get()),
out_grad->data<T>());
PADDLE_ENFORCE_EQ(
r, XPU_SUCCESS,
platform::errors::External(
"XPU API return wrong value[%d], please check whether "
"Baidu Kunlun Card is properly installed.",
r));
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_XPU_KERNEL(
concat, ops::ConcatXPUKernel<paddle::platform::XPUDeviceContext, float>);
REGISTER_OP_XPU_KERNEL(
concat_grad,
ops::ConcatGradXPUKernel<paddle::platform::XPUDeviceContext, float>);
#endif
......@@ -253,7 +253,8 @@ class TestConcatAPI(unittest.TestCase):
assert np.array_equal(res_3, np.concatenate((input_2, input_3), axis=1))
def test_api(self):
x_1 = paddle.fluid.data(shape=[None, 1, 4, 5], dtype='int32', name='x_1')
x_1 = paddle.fluid.data(
shape=[None, 1, 4, 5], dtype='int32', name='x_1')
paddle.concat([x_1, x_1], 0)
input_2 = np.random.random([2, 1, 4, 5]).astype("int32")
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import sys
sys.path.append("..")
import op_test
import numpy as np
import unittest
import paddle
import paddle.fluid.core as core
from paddle.fluid.op import Operator
import paddle.fluid as fluid
from paddle.fluid import compiler, Program, program_guard
from paddle.fluid.backward import append_backward
class TestAssignOp(op_test.OpTest):
def setUp(self):
self.op_type = "assign"
x = np.random.random(size=(100, 10)).astype('float32')
self.inputs = {'X': x}
self.outputs = {'Out': x}
def test_forward(self):
if paddle.is_compiled_with_xpu():
place = paddle.XPUPlace(0)
self.check_output_with_place(place)
def test_backward(self):
if paddle.is_compiled_with_xpu():
place = paddle.XPUPlace(0)
self.check_grad_with_place(place, ['X'], 'Out')
class TestAssignOpWithLoDTensorArray(unittest.TestCase):
def test_assign_LoDTensorArray(self):
main_program = Program()
startup_program = Program()
with program_guard(main_program):
x = fluid.data(name='x', shape=[100, 10], dtype='float32')
x.stop_gradient = False
y = fluid.layers.fill_constant(
shape=[100, 10], dtype='float32', value=1)
z = fluid.layers.elementwise_add(x=x, y=y)
i = fluid.layers.fill_constant(shape=[1], dtype='int64', value=0)
init_array = fluid.layers.array_write(x=z, i=i)
array = fluid.layers.assign(init_array)
sums = fluid.layers.array_read(array=init_array, i=i)
mean = fluid.layers.mean(sums)
append_backward(mean)
place = fluid.CUDAPlace(0) if core.is_compiled_with_cuda(
) else fluid.CPUPlace()
exe = fluid.Executor(place)
feed_x = np.random.random(size=(100, 10)).astype('float32')
ones = np.ones((100, 10)).astype('float32')
feed_add = feed_x + ones
res = exe.run(main_program,
feed={'x': feed_x},
fetch_list=[sums.name, x.grad_name])
self.assertTrue(np.allclose(res[0], feed_add))
self.assertTrue(np.allclose(res[1], ones / 1000.0))
class TestAssignOpError(unittest.TestCase):
def test_errors(self):
with program_guard(Program(), Program()):
# The type of input must be Variable or numpy.ndarray.
x1 = fluid.create_lod_tensor(
np.array([[-1]]), [[1]], fluid.XPUPlace(0))
self.assertRaises(TypeError, fluid.layers.assign, x1)
# When the type of input is Variable, the dtype of input must be float16, float32, float64, int32, int64, bool.
x3 = fluid.layers.data(name='x3', shape=[4], dtype="uint8")
self.assertRaises(TypeError, fluid.layers.assign, x3)
# When the type of input is numpy.ndarray, the dtype of input must be float32, int32.
x4 = np.array([[2.5, 2.5]], dtype='float64')
self.assertRaises(TypeError, fluid.layers.assign, x4)
x5 = np.array([[2.5, 2.5]], dtype='uint8')
self.assertRaises(TypeError, fluid.layers.assign, x5)
if __name__ == '__main__':
paddle.enable_static()
unittest.main()
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import sys
sys.path.append("..")
import op_test
import unittest
import numpy as np
import paddle
import paddle.fluid.core as core
import paddle.fluid as fluid
from paddle.fluid import compiler, Program, program_guard
class TestCastOp1(op_test.OpTest):
def setUp(self):
ipt = np.random.random(size=[10, 10])
self.inputs = {'X': ipt.astype('float32')}
self.outputs = {'Out': ipt.astype('float32')}
self.attrs = {
'in_dtype': int(core.VarDesc.VarType.FP32),
'out_dtype': int(core.VarDesc.VarType.FP32)
}
self.op_type = 'cast'
def test_check_output(self):
if paddle.is_compiled_with_xpu():
place = paddle.XPUPlace(0)
self.check_output_with_place(place)
def test_grad(self):
if paddle.is_compiled_with_xpu():
place = paddle.XPUPlace(0)
self.check_grad_with_place(place, ['X'], ['Out'])
class TestCastOp2(op_test.OpTest):
def setUp(self):
ipt = np.random.random(size=[10, 10])
self.inputs = {'X': ipt.astype('float32')}
self.outputs = {'Out': ipt.astype('float32')}
self.attrs = {
'in_dtype': int(core.VarDesc.VarType.FP32),
'out_dtype': int(core.VarDesc.VarType.FP32)
}
self.op_type = 'cast'
def test_check_output(self):
#self.check_output(atol=1e-3)
if paddle.is_compiled_with_xpu():
place = paddle.XPUPlace(0)
self.check_output_with_place(place, atol=1e-3)
class TestCastOp3(op_test.OpTest):
def setUp(self):
ipt = np.random.random(size=[10, 10])
self.inputs = {'X': ipt.astype('float32')}
self.outputs = {'Out': ipt.astype('float32')}
self.attrs = {
'in_dtype': int(core.VarDesc.VarType.FP32),
'out_dtype': int(core.VarDesc.VarType.FP32)
}
self.op_type = 'cast'
def test_check_output(self):
#self.check_output(atol=1e-3)
if paddle.is_compiled_with_xpu():
place = paddle.XPUPlace(0)
self.check_output_with_place(place, atol=1e-3)
class TestCastOpError(unittest.TestCase):
def test_errors(self):
with program_guard(Program(), Program()):
# The input type of cast_op must be Variable.
x1 = fluid.create_lod_tensor(
np.array([[-1]]), [[1]], fluid.XPUPlace(0))
self.assertRaises(TypeError, fluid.layers.cast, x1, 'int32')
# The input dtype of cast_op must be float32, int32, int64.
x2 = fluid.layers.data(name='x2', shape=[4], dtype='int16')
self.assertRaises(TypeError, fluid.layers.cast, x2, 'int32')
def test_dtype_type():
x4 = fluid.layers.data(name='x4', shape=[4], dtype='int32')
output = fluid.layers.cast(x=x4, dtype='int16')
self.assertRaises(TypeError, test_dtype_type)
if __name__ == '__main__':
paddle.enable_static()
unittest.main()
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import sys
sys.path.append("..")
import unittest
import numpy as np
from op_test import OpTest, skip_check_grad_ci
import paddle.fluid as fluid
from paddle.fluid import compiler, Program, program_guard, core
import paddle
class TestConcatOp(OpTest):
def setUp(self):
self.op_type = "concat"
self.dtype = self.get_dtype()
self.init_test_data()
self.inputs = {'X': [('x0', self.x0), ('x1', self.x1), ('x2', self.x2)]}
self.attrs = {'axis': self.axis}
if self.axis < 0:
self.actual_axis = self.axis + len(self.x0.shape)
self.actual_axis = self.actual_axis if self.actual_axis > 0 else 0
else:
self.actual_axis = self.axis
self.outputs = {
'Out': np.concatenate(
(self.x0, self.x1, self.x2), axis=self.actual_axis)
}
def get_dtype(self):
return "float64"
def test_check_output(self):
if paddle.is_compiled_with_xpu():
place = paddle.XPUPlace(0)
self.check_output_with_place(place)
def test_check_grad(self):
if paddle.is_compiled_with_xpu():
place = paddle.XPUPlace(0)
self.check_grad_with_place(place, ['x0'], 'Out')
self.check_grad_with_place(place, ['x1'], 'Out')
self.check_grad_with_place(place, ['x2'], 'Out')
def init_test_data(self):
self.x0 = np.random.random((5, 1, 4, 5)).astype(self.dtype)
self.x1 = np.random.random((5, 2, 4, 5)).astype(self.dtype)
self.x2 = np.random.random((5, 3, 4, 5)).astype(self.dtype)
self.axis = 1
class TestConcatOp2(TestConcatOp):
def init_test_data(self):
self.x0 = np.random.random((2, 3, 4, 5)).astype(self.dtype)
self.x1 = np.random.random((2, 3, 4, 5)).astype(self.dtype)
self.x2 = np.random.random((2, 3, 4, 5)).astype(self.dtype)
self.axis = 1
@skip_check_grad_ci(
reason="The function 'check_grad' for large inputs is too slow.")
class TestConcatOp3(TestConcatOp):
def init_test_data(self):
self.x0 = np.random.random((1, 256, 170, 256)).astype(self.dtype)
self.x1 = np.random.random((1, 128, 170, 256)).astype(self.dtype)
self.x2 = np.random.random((1, 128, 170, 256)).astype(self.dtype)
self.axis = 1
def test_check_grad(self):
pass
@skip_check_grad_ci(
reason="This test will meet fetch error when there is a null grad. The detailed information is in PR#17015."
)
class TestConcatOp4(TestConcatOp):
def init_test_data(self):
self.x0 = np.random.random((2, 3, 4, 5)).astype(self.dtype)
self.x1 = np.random.random((2, 3, 4, 5)).astype(self.dtype)
self.x2 = np.random.random((0, 3, 4, 5)).astype(self.dtype)
self.axis = 0
def test_check_grad(self):
pass
class TestConcatOp5(TestConcatOp):
def init_test_data(self):
self.x0 = np.random.random((5, 1, 4, 5)).astype(self.dtype)
self.x1 = np.random.random((5, 2, 4, 5)).astype(self.dtype)
self.x2 = np.random.random((5, 3, 4, 5)).astype(self.dtype)
self.axis = -3
class TestConcatOp6(TestConcatOp):
def setUp(self):
self.op_type = "concat"
self.dtype = self.get_dtype()
self.init_test_data()
self.lod = [[20, 80]]
self.out_lod = [[20, 80, 20, 80, 20, 80]]
self.inputs = {
'X': [('x0', (self.x0, self.lod)), ('x1', (self.x1, self.lod)),
('x2', (self.x2, self.lod))]
}
self.attrs = {'axis': self.axis}
if self.axis < 0:
self.actual_axis = self.axis + len(self.x0.shape)
self.actual_axis = self.actual_axis if self.actual_axis > 0 else 0
else:
self.actual_axis = self.axis
out = np.concatenate((self.x0, self.x1, self.x2), axis=self.actual_axis)
self.outputs = {'Out': (out, self.out_lod)}
def test_check_output(self):
if paddle.is_compiled_with_xpu():
place = paddle.XPUPlace(0)
self.check_output_with_place(place)
def test_check_grad(self):
if paddle.is_compiled_with_xpu():
place = paddle.XPUPlace(0)
self.check_grad_with_place(place, ['x0'], 'Out')
self.check_grad_with_place(place, ['x1'], 'Out')
self.check_grad_with_place(place, ['x2'], 'Out')
def init_test_data(self):
self.x0 = np.random.random([100]).astype(self.dtype)
self.x1 = np.random.random([100]).astype(self.dtype)
self.x2 = np.random.random([100]).astype(self.dtype)
self.axis = 0
class TestConcatOpError(unittest.TestCase):
def test_errors(self):
with program_guard(Program(), Program()):
# The input type of concat_op should be list.
x1 = fluid.layers.data(shape=[4], dtype='int32', name='x1')
fluid.layers.concat(x1)
# The item in input must be Variable.
x2 = fluid.create_lod_tensor(
np.array([[-1]]), [[1]], fluid.CPUPlace())
x3 = fluid.create_lod_tensor(
np.array([[-1]]), [[1]], fluid.CPUPlace())
self.assertRaises(TypeError, fluid.layers.concat, [x2])
# The input dtype of concat_op must be float16, float32, float64, int32, int64.
x4 = fluid.layers.data(shape=[4], dtype='uint8', name='x4')
x5 = fluid.layers.data(shape=[4], dtype='uint8', name='x5')
self.assertRaises(TypeError, fluid.layers.concat, [x4, x5])
x6 = fluid.layers.data(shape=[4], dtype='float16', name='x6')
x7 = fluid.layers.data(shape=[4], dtype='float16', name='x7')
x8 = fluid.layers.data(shape=[4], dtype='float32', name='x8')
fluid.layers.concat([x6, x7])
# The type of axis in concat_op should be int or Variable.
def test_axis_type():
fluid.layers.concat([x6, x7], 3.2)
self.assertRaises(TypeError, test_axis_type)
def test_input_same_dtype():
fluid.layers.concat([x7, x8])
self.assertRaises(TypeError, test_input_same_dtype)
class TestConcatAPI(unittest.TestCase):
def test_fluid_api(self):
x_1 = fluid.data(shape=[None, 1, 4, 5], dtype='float32', name='x_1')
fluid.layers.concat([x_1, x_1], 0)
input_2 = np.random.random([2, 1, 4, 5]).astype("float32")
input_3 = np.random.random([2, 2, 4, 5]).astype("float32")
x_2 = fluid.data(shape=[2, 1, 4, 5], dtype='float32', name='x_2')
x_3 = fluid.data(shape=[2, 2, 4, 5], dtype='float32', name='x_3')
positive_1_int32 = fluid.layers.fill_constant([1], "float32", 1)
positive_1_int64 = fluid.layers.fill_constant([1], "float32", 1)
out_1 = fluid.layers.concat(input=[x_2, x_3], axis=1)
out_2 = fluid.layers.concat(input=[x_2, x_3], axis=1)
out_3 = fluid.layers.concat(input=[x_2, x_3], axis=1)
exe = fluid.Executor(place=fluid.XPUPlace(0))
[res_1, res_2, res_3] = exe.run(
fluid.default_main_program(),
feed={"x_1": input_2,
"x_2": input_2,
"x_3": input_3},
fetch_list=[out_1, out_2, out_3])
assert np.array_equal(res_1, np.concatenate((input_2, input_3), axis=1))
assert np.array_equal(res_2, np.concatenate((input_2, input_3), axis=1))
assert np.array_equal(res_3, np.concatenate((input_2, input_3), axis=1))
def test_errors(self):
with program_guard(Program(), Program()):
# The item in input must be Variable.
x2 = fluid.create_lod_tensor(
np.array([[-1]]), [[1]], fluid.XPUPlace(0))
x3 = fluid.create_lod_tensor(
np.array([[-1]]), [[1]], fluid.XPUPlace(0))
self.assertRaises(TypeError, paddle.concat, [x2])
# The input dtype of concat_op must be float32.
x4 = fluid.data(shape=[4], dtype='uint8', name='x4')
x5 = fluid.data(shape=[4], dtype='uint8', name='x5')
self.assertRaises(TypeError, fluid.layers.concat, [x4, x5])
# The type of axis in concat_op should be int or Variable.
x6 = fluid.layers.data(shape=[4], dtype='float16', name='x6')
x7 = fluid.layers.data(shape=[4], dtype='float16', name='x7')
x8 = fluid.layers.data(shape=[4], dtype='float32', name='x8')
def test_axis_type():
paddle.concat([x6, x7], 3.2)
self.assertRaises(TypeError, test_axis_type)
def test_input_same_dtype():
paddle.concat([x7, x8])
self.assertRaises(TypeError, test_input_same_dtype)
if __name__ == '__main__':
paddle.enable_static()
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册