未验证 提交 b7355d8e 编写于 作者: R ronnywang 提交者: GitHub

[NPU] add broadcast supporting for elementwise_add_op_npu (#34057)

* add broadcast supporting for elementwise_add

* add broadcast supporting for elementwise_add

* add more tests

* remove the redundant code

* update

* fix place error in unittest

* remove skip.If
上级 338f9e05
......@@ -17,6 +17,7 @@ limitations under the License. */
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/operators/elementwise/elementwise_add_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_npu.h"
#include "paddle/fluid/operators/npu_op_runner.h"
namespace paddle {
......@@ -27,12 +28,37 @@ template <typename T>
class ElementwiseAddNPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto& dev_ctx =
ctx.template device_context<paddle::platform::NPUDeviceContext>();
auto* x = ctx.Input<framework::LoDTensor>("X");
auto* y = ctx.Input<framework::LoDTensor>("Y");
auto* out = ctx.Output<framework::LoDTensor>("Out");
out->mutable_data<T>(ctx.GetPlace());
const auto& runner = NpuOpRunner("Add", {*x, *y}, {*out}, {});
int axis = ctx.Attr<int>("axis");
bool direct_compute = false;
auto x_dims = x->dims();
auto y_dims = y->dims();
axis = (axis == -1 ? std::abs(x_dims.size() - y_dims.size()) : axis);
if (x_dims.size() >= y_dims.size()) {
direct_compute =
y_dims == framework::slice_ddim(x_dims, axis, x_dims.size());
} else {
direct_compute =
x_dims == framework::slice_ddim(y_dims, axis, y_dims.size());
}
Tensor transformed_x, transformed_y;
if (direct_compute) {
transformed_x.ShareDataWith(*x);
transformed_y.ShareDataWith(*y);
} else {
NpuElementWiseOpBroadcast<T>(dev_ctx, x, y, axis, &transformed_x,
&transformed_y);
}
const auto& runner =
NpuOpRunner("Add", {transformed_x, transformed_y}, {*out}, {});
auto stream =
ctx.template device_context<paddle::platform::NPUDeviceContext>()
.stream();
......@@ -44,109 +70,75 @@ template <typename T>
class ElementwiseAddGradNPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
auto stream =
ctx.template device_context<paddle::platform::NPUDeviceContext>()
.stream();
// NOTE(zhiqiu): It seems Ascend Sub follow the broadcast sematics with
// default axis=-1?
// So, the sub_grad should do reduce if needed.
// For example, the shape of each variable in elementwise_sub:
// x, dx: [2, 3, 5]
// y, dy: [1, 5]
// out, dout: [2, 3, 5]
// Then, out = x - y => dx = dout, dy = -dout
// And, the shape of dy can be computed by two stages reduce,
// 1. [2, 3, 5] => [3, 5], ReduceSumD on axis = 0, keep_dims = false.
// 2. [3, 5] => [1, 5], ReduceSumD on axis = 0, keep_dims = true.
auto& dev_ctx =
ctx.template device_context<paddle::platform::NPUDeviceContext>();
auto* x = ctx.Input<framework::Tensor>("X");
auto* y = ctx.Input<framework::Tensor>("Y");
auto* dout = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
auto* dx = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
auto* dy = ctx.Output<framework::Tensor>(framework::GradVarName("Y"));
int axis = ctx.Attr<int>("axis");
axis = (axis == -1 ? std::abs(x->dims().size() - y->dims().size()) : axis);
auto stream = dev_ctx.stream();
if (dx) {
dx->mutable_data<T>(ctx.GetPlace());
// For dx
// stage 1
auto reduce_ndim = dout->dims().size() - dx->dims().size();
std::vector<int> axes;
for (auto i = 0; i < reduce_ndim; ++i) {
axes.push_back(i);
}
Tensor* tmp_dout = const_cast<Tensor*>(dout);
Tensor reduced_dout(dx->type());
if (axes.size() != 0) {
std::vector<int64_t> reduced_dout_dims;
for (auto i = reduce_ndim; i < dout->dims().size(); ++i) {
reduced_dout_dims.push_back(dout->dims()[i]);
if (dx->dims() != dout->dims()) {
std::vector<int> dst_dims_vec;
std::vector<int> reduce_axes;
auto src_dims = dx->dims();
auto dout_dims = dout->dims();
int src_axis = (src_dims.size() < dout_dims.size() ? axis : 0);
for (int ax = 0; ax < dout_dims.size(); ++ax) {
if ((ax < src_axis || ax >= src_axis + src_dims.size()) ||
(dout_dims[ax] > 1 && src_dims[ax - src_axis] == 1)) {
reduce_axes.push_back(ax);
} else {
dst_dims_vec.push_back(dout_dims[ax]);
}
}
reduced_dout.Resize(framework::make_ddim(reduced_dout_dims));
reduced_dout.mutable_data<T>(ctx.GetPlace());
const auto& runner =
NpuOpRunner("ReduceSumD", {*dout}, {reduced_dout},
{{"axes", axes}, {"keep_dims", false}});
runner.Run(stream);
tmp_dout = &reduced_dout;
}
// stage 2
axes.clear();
for (auto i = 0; i < dx->dims().size(); ++i) {
if (dx->dims()[i] == 1) {
axes.push_back(i);
if (!reduce_axes.empty()) {
Tensor tmp;
tmp.ShareDataWith(*dx);
tmp.Resize(framework::make_ddim(dst_dims_vec));
const auto& runner =
NpuOpRunner("ReduceSumD", {*dout}, {tmp},
{{"axes", reduce_axes}, {"keep_dims", false}});
runner.Run(stream);
}
}
if (axes.size() != 0) {
const auto& runner = NpuOpRunner("ReduceSumD", {*tmp_dout}, {*dx},
{{"axes", axes}, {"keep_dims", true}});
runner.Run(stream);
} else {
framework::TensorCopy(
*tmp_dout, ctx.GetPlace(),
ctx.template device_context<platform::DeviceContext>(), dx);
framework::TensorCopy(*dout, ctx.GetPlace(), dev_ctx, dx);
}
}
if (dy) {
// For dy
// stage 1
auto reduce_ndim = dout->dims().size() - dy->dims().size();
std::vector<int> axes;
for (auto i = 0; i < reduce_ndim; ++i) {
axes.push_back(i);
}
Tensor* tmp_dout = const_cast<Tensor*>(dout);
Tensor reduced_dout(dout->type());
if (axes.size() != 0) {
std::vector<int64_t> reduced_dout_dims;
for (auto i = reduce_ndim; i < dout->dims().size(); ++i) {
reduced_dout_dims.push_back(dout->dims()[i]);
dy->mutable_data<T>(ctx.GetPlace());
if (dy->dims() != dout->dims()) {
std::vector<int> dst_dims_vec;
std::vector<int> reduce_axes;
auto src_dims = dy->dims();
auto dout_dims = dout->dims();
int src_axis = (src_dims.size() < dout_dims.size() ? axis : 0);
for (int ax = 0; ax < dout_dims.size(); ++ax) {
if ((ax < src_axis || ax >= src_axis + src_dims.size()) ||
(dout_dims[ax] > 1 && src_dims[ax - src_axis] == 1)) {
reduce_axes.push_back(ax);
} else {
dst_dims_vec.push_back(dout_dims[ax]);
}
}
reduced_dout.Resize(framework::make_ddim(reduced_dout_dims));
reduced_dout.mutable_data<T>(ctx.GetPlace());
const auto& runner =
NpuOpRunner("ReduceSumD", {*dout}, {reduced_dout},
{{"axes", axes}, {"keep_dims", false}});
runner.Run(stream);
tmp_dout = &reduced_dout;
}
// stage 2
axes.clear();
for (auto i = 0; i < dy->dims().size(); ++i) {
if (dy->dims()[i] == 1) {
axes.push_back(i);
if (!reduce_axes.empty()) {
Tensor tmp;
tmp.ShareDataWith(*dy);
tmp.Resize(framework::make_ddim(dst_dims_vec));
const auto& runner =
NpuOpRunner("ReduceSumD", {*dout}, {tmp},
{{"axes", reduce_axes}, {"keep_dims", false}});
runner.Run(stream);
}
}
if (axes.size() != 0) {
dy->mutable_data<T>(ctx.GetPlace());
const auto& runner = NpuOpRunner("ReduceSumD", {*tmp_dout}, {*dy},
{{"axes", axes}, {"keep_dims", true}});
runner.Run(stream);
} else {
framework::TensorCopy(
*tmp_dout, ctx.GetPlace(),
ctx.template device_context<platform::DeviceContext>(), dy);
framework::TensorCopy(*dout, ctx.GetPlace(), dev_ctx, dy);
}
}
}
......
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/operators/elementwise/elementwise_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
#include "paddle/fluid/operators/npu_op_runner.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename T>
void NpuBroadcast(const platform::NPUDeviceContext& dev_ctx, const Tensor* src,
int axis, const framework::DDim& dst_dims,
Tensor* transformed_src) {
auto stream = dev_ctx.stream();
// 1. expand the axis with dim 1
auto src_dims = src->dims();
Tensor tmp_src;
tmp_src.ShareDataWith(*src);
tmp_src.Resize(src_dims);
for (int i = 0; i < src_dims.size(); ++i) {
if (src_dims[i] == 1 && dst_dims[i + axis] > 1) {
Tensor tmp_tensor;
auto tmp_tensor_dims = tmp_src.dims();
tmp_tensor_dims[i] = dst_dims[i + axis];
tmp_tensor.mutable_data<T>(tmp_tensor_dims, dev_ctx.GetPlace());
const auto& runner =
NpuOpRunner("TileWithAxis", {tmp_src}, {tmp_tensor},
{{"axis", static_cast<int64_t>(i)},
{"tiles", static_cast<int64_t>(dst_dims[i + axis])}});
runner.Run(stream);
tmp_src.ShareDataWith(tmp_tensor);
tmp_src.Resize(tmp_tensor_dims);
}
}
// 2.expand the ahead axis
auto prev = framework::product(framework::slice_ddim(dst_dims, 0, axis));
if (prev > 1) {
Tensor tmp_tensor;
auto tmp_tensor_dims =
framework::slice_ddim(dst_dims, 0, axis + src_dims.size());
tmp_tensor.mutable_data<T>(tmp_tensor_dims, dev_ctx.GetPlace());
const auto& runner = NpuOpRunner(
"ExpandD", {tmp_src}, {tmp_tensor},
{{"shape", framework::vectorize<int64_t>(tmp_tensor_dims)}});
runner.Run(stream);
tmp_src.ShareDataWith(tmp_tensor);
tmp_src.Resize(tmp_tensor_dims);
} else {
tmp_src.Resize(framework::slice_ddim(dst_dims, 0, axis + src_dims.size()));
}
// 3.expand the tail axis
auto post = framework::product(
framework::slice_ddim(dst_dims, axis + src_dims.size(), dst_dims.size()));
if (post > 1) {
auto src_dims_vec = framework::vectorize<int>(tmp_src.dims());
src_dims_vec.push_back(1);
tmp_src.Resize(framework::make_ddim(src_dims_vec));
Tensor tmp_tensor;
tmp_tensor.mutable_data<T>(dst_dims, dev_ctx.GetPlace());
const auto& runner =
NpuOpRunner("TileWithAxis", {tmp_src}, {tmp_tensor},
{{"axis", static_cast<int64_t>(axis + src_dims.size())},
{"tiles", static_cast<int64_t>(post)}});
runner.Run(stream);
tmp_src.ShareDataWith(tmp_tensor);
}
tmp_src.Resize(dst_dims);
framework::TensorCopy(tmp_src, dev_ctx.GetPlace(), transformed_src);
}
template <typename T>
void NpuElementWiseOpBroadcast(const platform::NPUDeviceContext& dev_ctx,
const Tensor* x, const Tensor* y, int axis,
Tensor* transformed_x, Tensor* transformed_y) {
auto x_dims = x->dims();
auto y_dims = y->dims();
bool is_xsize_larger = true;
int max_dim = x_dims.size();
std::vector<int> dst_dims_vec = framework::vectorize<int>(x_dims);
if (x_dims.size() < y_dims.size()) {
is_xsize_larger = false;
max_dim = y_dims.size();
dst_dims_vec = framework::vectorize<int>(y_dims);
}
axis = (axis == -1 ? std::abs(x_dims.size() - y_dims.size()) : axis);
int x_axis = is_xsize_larger ? 0 : axis;
int y_axis = is_xsize_larger ? axis : 0;
PADDLE_ENFORCE_GE(
axis, 0,
platform::errors::InvalidArgument(
"Axis should be great than or equal to 0, but received axis is %d.",
axis));
PADDLE_ENFORCE_LT(axis, max_dim,
platform::errors::InvalidArgument(
"Axis should be less than %d, but received axis is %d.",
max_dim, axis));
for (int i = 0; i < x_dims.size(); ++i) {
dst_dims_vec[i + x_axis] =
std::max(dst_dims_vec[i + x_axis], static_cast<int>(x_dims[i]));
}
for (int i = 0; i < y_dims.size(); ++i) {
dst_dims_vec[i + y_axis] =
std::max(dst_dims_vec[i + y_axis], static_cast<int>(y_dims[i]));
}
auto dst_dims = framework::make_ddim(dst_dims_vec);
NpuBroadcast<T>(dev_ctx, x, x_axis, dst_dims, transformed_x);
NpuBroadcast<T>(dev_ctx, y, y_axis, dst_dims, transformed_y);
}
} // namespace operators
} // namespace paddle
......@@ -13,14 +13,16 @@
# limitations under the License.
from __future__ import print_function
import numpy as np
import unittest
import sys
sys.path.append("..")
from op_test import OpTest, _set_use_system_allocator
import paddle
from paddle.fluid import Program, program_guard
import paddle.fluid.core as core
import paddle.fluid as fluid
import paddle
from op_test import OpTest, skip_check_grad_ci
paddle.enable_static()
......@@ -63,6 +65,9 @@ class TestElementwiseAddOp(OpTest):
self.check_output_with_place(self.place)
def test_check_grad_normal(self):
if self.dtype == np.float16:
return
self.check_grad_with_place(
self.place,
['X', 'Y'],
......@@ -70,6 +75,9 @@ class TestElementwiseAddOp(OpTest):
max_relative_error=0.006, )
def test_check_grad_ingore_x(self):
if self.dtype == np.float16:
return
self.check_grad_with_place(
self.place,
['Y'],
......@@ -78,6 +86,9 @@ class TestElementwiseAddOp(OpTest):
max_relative_error=0.006, )
def test_check_grad_ingore_y(self):
if self.dtype == np.float16:
return
self.check_grad_with_place(
self.place,
['X'],
......@@ -86,6 +97,47 @@ class TestElementwiseAddOp(OpTest):
max_relative_error=0.006, )
class TestFP16ElementwiseAddOp(TestElementwiseAddOp):
def init_dtype(self):
self.dtype = np.float16
@skip_check_grad_ci(
reason="[skip shape check] Use y_shape(1) to test broadcast.")
class TestElementwiseAddOp_scalar(TestElementwiseAddOp):
def init_input_output(self):
self.x = np.random.rand(2, 3, 4).astype(self.dtype)
self.y = np.random.rand(1).astype(self.dtype)
self.out = self.x + self.y
@skip_check_grad_ci(
reason="[skip shape check] Use y_shape(1) to test broadcast.")
class TestFP16ElementwiseAddOp_scalar(TestFP16ElementwiseAddOp):
def init_input_output(self):
self.x = np.random.rand(2, 3, 4).astype(self.dtype)
self.y = np.random.rand(1).astype(self.dtype)
self.out = self.x + self.y
@skip_check_grad_ci(
reason="[skip shape check] Use y_shape(1,1) to test broadcast.")
class TestElementwiseAddOp_scalar2(TestElementwiseAddOp):
def init_input_output(self):
self.x = np.random.rand(2, 3, 4).astype(self.dtype)
self.y = np.random.rand(1, 1).astype(self.dtype)
self.out = self.x + self.y
@skip_check_grad_ci(
reason="[skip shape check] Use y_shape(1,1) to test broadcast.")
class TestFP16ElementwiseAddOp_scalar2(TestFP16ElementwiseAddOp):
def init_input_output(self):
self.x = np.random.rand(2, 3, 4).astype(self.dtype)
self.y = np.random.rand(1, 1).astype(self.dtype)
self.out = self.x + self.y
class TestAddAPI(unittest.TestCase):
def test_name(self):
with paddle.static.program_guard(paddle.static.Program()):
......@@ -148,5 +200,385 @@ class TestAddError(unittest.TestCase):
self.assertRaises(TypeError, paddle.add, x2, y2)
class TestElementwiseAddOp_Vector(TestElementwiseAddOp):
def init_input_output(self):
self.x = np.random.random((100, )).astype(self.dtype)
self.y = np.random.random((100, )).astype(self.dtype)
self.out = np.add(self.x, self.y)
class TestFP16ElementwiseAddOp_Vector(TestFP16ElementwiseAddOp):
def init_input_output(self):
self.x = np.random.random((100, )).astype(self.dtype)
self.y = np.random.random((100, )).astype(self.dtype)
self.out = np.add(self.x, self.y)
class TestElementwiseAddOp_broadcast_0(TestElementwiseAddOp):
def init_input_output(self):
self.x = np.random.rand(100, 2, 3).astype(self.dtype)
self.y = np.random.rand(100).astype(self.dtype)
self.out = self.x + self.y.reshape(100, 1, 1)
def init_axis(self):
self.axis = 0
class TestFP16ElementwiseAddOp_broadcast_0(TestFP16ElementwiseAddOp):
def init_input_output(self):
self.x = np.random.rand(100, 2, 3).astype(self.dtype)
self.y = np.random.rand(100).astype(self.dtype)
self.out = self.x + self.y.reshape(100, 1, 1)
def init_axis(self):
self.axis = 0
class TestElementwiseAddOp_broadcast_1(TestElementwiseAddOp):
def init_input_output(self):
self.x = np.random.rand(2, 100, 3).astype(self.dtype)
self.y = np.random.rand(100).astype(self.dtype)
self.out = self.x + self.y.reshape(1, 100, 1)
def init_axis(self):
self.axis = 1
class TestFP16ElementwiseAddOp_broadcast_1(TestFP16ElementwiseAddOp):
def init_input_output(self):
self.x = np.random.rand(2, 100, 3).astype(self.dtype)
self.y = np.random.rand(100).astype(self.dtype)
self.out = self.x + self.y.reshape(1, 100, 1)
def init_axis(self):
self.axis = 1
class TestElementwiseAddOp_broadcast_2(TestElementwiseAddOp):
def init_input_output(self):
self.x = np.random.rand(2, 3, 100).astype(self.dtype)
self.y = np.random.rand(100).astype(self.dtype)
self.out = self.x + self.y.reshape(1, 1, 100)
class TestFP16ElementwiseAddOp_broadcast_2(TestFP16ElementwiseAddOp):
def init_input_output(self):
self.x = np.random.rand(2, 3, 100).astype(self.dtype)
self.y = np.random.rand(100).astype(self.dtype)
self.out = self.x + self.y.reshape(1, 1, 100)
class TestElementwiseAddOp_broadcast_3(TestElementwiseAddOp):
def init_input_output(self):
self.x = np.random.rand(2, 10, 12, 1).astype(self.dtype)
self.y = np.random.rand(10, 12).astype(self.dtype)
self.out = self.x + self.y.reshape(1, 10, 12, 1)
def init_axis(self):
self.axis = 1
class TestFP16ElementwiseAddOp_broadcast_3(TestFP16ElementwiseAddOp):
def init_input_output(self):
self.x = np.random.rand(2, 10, 12, 3).astype(self.dtype)
self.y = np.random.rand(10, 12).astype(self.dtype)
self.out = self.x + self.y.reshape(1, 10, 12, 1)
def init_axis(self):
self.axis = 1
class TestElementwiseAddOp_broadcast_4(TestElementwiseAddOp):
def init_input_output(self):
self.x = np.random.rand(100, 2, 1, 2).astype(self.dtype)
self.y = np.random.rand(100, 1).astype(self.dtype)
self.out = self.x + self.y.reshape(100, 1, 1, 1)
def init_axis(self):
self.axis = 0
class TestFP16ElementwiseAddOp_broadcast_4(TestFP16ElementwiseAddOp):
def init_input_output(self):
self.x = np.random.rand(100, 2, 1, 2).astype(self.dtype)
self.y = np.random.rand(100, 1).astype(self.dtype)
self.out = self.x + self.y.reshape(100, 1, 1, 1)
def init_axis(self):
self.axis = 0
class TestElementwiseAddOp_broadcast_5(TestElementwiseAddOp):
def init_input_output(self):
self.x = np.random.rand(10, 3, 12).astype(self.dtype)
self.y = np.random.rand(10, 1, 12).astype(self.dtype)
self.out = self.x + self.y
class TestFP16ElementwiseAddOp_broadcast_5(TestFP16ElementwiseAddOp):
def init_input_output(self):
self.x = np.random.rand(10, 3, 12).astype(self.dtype)
self.y = np.random.rand(10, 1, 12).astype(self.dtype)
self.out = self.x + self.y
class TestElementwiseAddOp_broadcast_6(TestElementwiseAddOp):
def init_input_output(self):
self.x = np.random.rand(2, 12, 3, 5).astype(self.dtype)
self.y = np.random.rand(2, 12, 1, 5).astype(self.dtype)
self.out = self.x + self.y
class TestElementwiseAddOp_broadcast_7(TestElementwiseAddOp):
def init_input_output(self):
self.x = np.random.rand(1, 1, 20, 5).astype(self.dtype)
self.y = np.random.rand(20, 5, 1, 1).astype(self.dtype)
self.out = self.x + self.y
class TestFP16ElementwiseAddOp_broadcast_6(TestFP16ElementwiseAddOp):
def init_input_output(self):
self.x = np.random.rand(2, 12, 3, 5).astype(self.dtype)
self.y = np.random.rand(2, 12, 1, 5).astype(self.dtype)
self.out = self.x + self.y
class TestElementwiseAddOp_rowwise_add_0(TestElementwiseAddOp):
def init_input_output(self):
self.x = np.random.rand(2, 10, 12).astype(self.dtype)
self.y = np.random.rand(10, 12).astype(self.dtype)
self.out = self.x + self.y.reshape(1, 10, 12)
def init_axis(self):
self.axis = 1
class TestFP16ElementwiseAddOp_rowwise_add_0(TestFP16ElementwiseAddOp):
def init_input_output(self):
self.x = np.random.rand(2, 10, 12).astype(self.dtype)
self.y = np.random.rand(10, 12).astype(self.dtype)
self.out = self.x + self.y.reshape(1, 10, 12)
def init_axis(self):
self.axis = 1
@skip_check_grad_ci(
reason="[skip shape check] Use y_shape(1) to test broadcast.")
class TestElementwiseAddOp_rowwise_add_1(TestElementwiseAddOp):
def init_input_output(self):
self.x = np.random.rand(100, 1).astype(self.dtype)
self.y = np.random.rand(1).astype(self.dtype)
self.out = self.x + self.y.reshape(1, 1)
def init_axis(self):
self.axis = 1
@skip_check_grad_ci(
reason="[skip shape check] Use y_shape(1) to test broadcast.")
class TestFP16ElementwiseAddOp_rowwise_add_1(TestFP16ElementwiseAddOp):
def init_input_output(self):
self.x = np.random.rand(100, 1).astype(self.dtype)
self.y = np.random.rand(1).astype(self.dtype)
self.out = self.x + self.y.reshape(1, 1)
def init_axis(self):
self.axis = 1
class TestElementwiseAddOp_channelwise_add(TestElementwiseAddOp):
def init_input_output(self):
self.x = np.random.rand(100, 2, 3).astype(self.dtype)
self.y = np.random.rand(100, 1, 1).astype(self.dtype)
self.out = self.x + self.y
def init_axis(self):
self.axis = -1
class TestFP16ElementwiseAddOp_channelwise_add(TestFP16ElementwiseAddOp):
def init_input_output(self):
self.x = np.random.rand(100, 2, 3).astype(self.dtype)
self.y = np.random.rand(100, 1, 1).astype(self.dtype)
self.out = self.x + self.y
def init_axis(self):
self.axis = -1
class TestElementwiseAddOp_commonuse_add1(TestElementwiseAddOp):
def init_input_output(self):
self.x = np.random.rand(2, 3, 100).astype(self.dtype)
self.y = np.random.rand(1, 1, 100).astype(self.dtype)
self.out = self.x + self.y
def init_axis(self):
self.axis = -1
class TestElementwiseFP16AddOp_commonuse_add1(TestFP16ElementwiseAddOp):
def init_input_output(self):
self.x = np.random.rand(2, 3, 100).astype(self.dtype)
self.y = np.random.rand(1, 1, 100).astype(self.dtype)
self.out = self.x + self.y
def init_axis(self):
self.axis = -1
class TestElementwiseAddOp_commonuse_add2(TestElementwiseAddOp):
def init_input_output(self):
self.x = np.random.rand(10, 3, 1, 4).astype(self.dtype)
self.y = np.random.rand(10, 1, 12, 1).astype(self.dtype)
self.out = self.x + self.y
def init_axis(self):
self.axis = -1
class TestElementwiseAddOp_xsize_lessthan_ysize_add(TestElementwiseAddOp):
def init_input_output(self):
self.x = np.random.rand(10, 12).astype(self.dtype)
self.y = np.random.rand(2, 2, 10, 12).astype(self.dtype)
self.out = self.x + self.y
def init_axis(self):
self.axis = 2
class TestElementwiseAddOp_same_shape_ysize_large(TestElementwiseAddOp):
def init_input_output(self):
self.x = np.random.rand(10, 1, 12).astype(self.dtype)
self.y = np.random.rand(10, 2, 12).astype(self.dtype)
self.out = self.x + self.y
def init_axis(self):
self.axis = 0
class TestElementwiseAddOpError(unittest.TestCase):
def test_errors(self):
with program_guard(Program(), Program()):
# the input of elementwise_add must be Variable.
x1 = fluid.create_lod_tensor(
np.array([-1, 3, 5, 5]), [[1, 1, 1, 1]], fluid.NPUPlace(0))
y1 = fluid.create_lod_tensor(
np.array([-1, 3, 5, 5]), [[1, 1, 1, 1]], fluid.NPUPlace(0))
self.assertRaises(TypeError, fluid.layers.elementwise_add, x1, y1)
# the input dtype of elementwise_add must be float16 or float32 or float64 or int32 or int64
# float16 only can be set on GPU place
x2 = fluid.layers.data(name='x2', shape=[3, 4, 5, 6], dtype="uint8")
y2 = fluid.layers.data(name='y2', shape=[3, 4, 5, 6], dtype="uint8")
self.assertRaises(TypeError, fluid.layers.elementwise_add, x2, y2)
class TestAddApi(unittest.TestCase):
def _executed_api(self, x, y, name=None):
return paddle.add(x, y, name)
def test_name(self):
with fluid.program_guard(fluid.Program()):
x = fluid.data(name="x", shape=[2, 3], dtype="float32")
y = fluid.data(name='y', shape=[2, 3], dtype='float32')
y_1 = self._executed_api(x, y, name='add_res')
self.assertEqual(('add_res' in y_1.name), True)
def test_declarative(self):
with fluid.program_guard(fluid.Program()):
def gen_data():
return {
"x": np.array([2, 3, 4]).astype('float32'),
"y": np.array([1, 5, 2]).astype('float32')
}
x = fluid.data(name="x", shape=[3], dtype='float32')
y = fluid.data(name="y", shape=[3], dtype='float32')
z = self._executed_api(x, y)
place = fluid.NPUPlace(0)
exe = fluid.Executor(place)
z_value = exe.run(feed=gen_data(), fetch_list=[z.name])
z_expected = np.array([3., 8., 6.])
self.assertEqual((z_value == z_expected).all(), True)
def test_dygraph(self):
with fluid.dygraph.guard(paddle.NPUPlace(0)):
np_x = np.array([2, 3, 4]).astype('float64')
np_y = np.array([1, 5, 2]).astype('float64')
x = fluid.dygraph.to_variable(np_x)
y = fluid.dygraph.to_variable(np_y)
z = self._executed_api(x, y)
np_z = z.numpy()
z_expected = np.array([3., 8., 6.])
self.assertEqual((np_z == z_expected).all(), True)
class TestAddInplaceApi(TestAddApi):
def _executed_api(self, x, y, name=None):
return x.add_(y, name)
class TestAddInplaceBroadcastSuccess(unittest.TestCase):
def init_data(self):
self.x_numpy = np.random.rand(2, 3, 4).astype('float')
self.y_numpy = np.random.rand(3, 4).astype('float')
def test_broadcast_success(self):
paddle.disable_static(place=paddle.NPUPlace(0))
self.init_data()
x = paddle.to_tensor(self.x_numpy)
y = paddle.to_tensor(self.y_numpy)
inplace_result = x.add_(y)
numpy_result = self.x_numpy + self.y_numpy
self.assertEqual((inplace_result.numpy() == numpy_result).all(), True)
paddle.enable_static()
class TestAddInplaceBroadcastSuccess2(TestAddInplaceBroadcastSuccess):
def init_data(self):
self.x_numpy = np.random.rand(1, 2, 3, 1).astype('float')
self.y_numpy = np.random.rand(3, 1).astype('float')
class TestAddInplaceBroadcastSuccess3(TestAddInplaceBroadcastSuccess):
def init_data(self):
self.x_numpy = np.random.rand(2, 3, 1, 5).astype('float')
self.y_numpy = np.random.rand(1, 3, 1, 5).astype('float')
class TestAddInplaceBroadcastError(unittest.TestCase):
def init_data(self):
self.x_numpy = np.random.rand(3, 4).astype('float')
self.y_numpy = np.random.rand(2, 3, 4).astype('float')
def test_broadcast_errors(self):
paddle.disable_static(place=paddle.NPUPlace(0))
self.init_data()
x = paddle.to_tensor(self.x_numpy)
y = paddle.to_tensor(self.y_numpy)
def broadcast_shape_error():
x.add_(y)
self.assertRaises(ValueError, broadcast_shape_error)
paddle.enable_static()
class TestAddInplaceBroadcastError2(TestAddInplaceBroadcastError):
def init_data(self):
self.x_numpy = np.random.rand(2, 1, 4).astype('float')
self.y_numpy = np.random.rand(2, 3, 4).astype('float')
class TestAddInplaceBroadcastError3(TestAddInplaceBroadcastError):
def init_data(self):
self.x_numpy = np.random.rand(5, 2, 1, 4).astype('float')
self.y_numpy = np.random.rand(2, 3, 4).astype('float')
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册