未验证 提交 1d898669 编写于 作者: F fengjiayi 提交者: GitHub

Merge pull request #7593 from JiayiFeng/dev_elementwise_scalar

Make elementwise_op supporting scalar input `Y`
...@@ -19,11 +19,16 @@ limitations under the License. */ ...@@ -19,11 +19,16 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace operators { namespace operators {
template <typename T>
struct DivFunctor {
inline HOSTDEVICE T operator()(T a, T b) const { return a / b; }
};
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
class ElementwiseDivKernel : public framework::OpKernel<T> { class ElementwiseDivKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
ElementwiseCompute<EigenDivFunctor, DeviceContext, T>(ctx); ElementwiseComputeEx<DivFunctor<T>, DeviceContext, T>(ctx);
} }
}; };
......
...@@ -18,11 +18,16 @@ limitations under the License. */ ...@@ -18,11 +18,16 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace operators { namespace operators {
template <typename T>
struct MulFunctor {
inline HOSTDEVICE T operator()(T a, T b) const { return a * b; }
};
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
class ElementwiseMulKernel : public framework::OpKernel<T> { class ElementwiseMulKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
ElementwiseCompute<EigenMulFunctor, DeviceContext, T>(ctx); ElementwiseComputeEx<MulFunctor<T>, DeviceContext, T>(ctx);
} }
}; };
......
...@@ -340,6 +340,13 @@ void ElementwiseGradCompute(const framework::ExecutionContext& ctx) { ...@@ -340,6 +340,13 @@ void ElementwiseGradCompute(const framework::ExecutionContext& ctx) {
return; return;
} }
if (y_dims.size() == 1 && y_dims[0] == 1) {
// y is a scalar
auto extended_dims = framework::vectorize(x_dims);
extended_dims.push_back(1);
x_dims = framework::make_ddim(extended_dims);
}
int axis = ctx.Attr<int>("axis"); int axis = ctx.Attr<int>("axis");
axis = (axis == -1 ? x_dims.size() - y_dims.size() : axis); axis = (axis == -1 ? x_dims.size() - y_dims.size() : axis);
...@@ -378,6 +385,13 @@ void ElementwiseComputeEx(const framework::ExecutionContext& ctx) { ...@@ -378,6 +385,13 @@ void ElementwiseComputeEx(const framework::ExecutionContext& ctx) {
return; return;
} }
if (y_dims.size() == 1 && y_dims[0] == 1) {
// y is a scalar
auto extended_dims = framework::vectorize(x_dims);
extended_dims.push_back(1);
x_dims = framework::make_ddim(extended_dims);
}
int axis = ctx.Attr<int>("axis"); int axis = ctx.Attr<int>("axis");
axis = (axis == -1 ? x_dims.size() - y_dims.size() : axis); axis = (axis == -1 ? x_dims.size() - y_dims.size() : axis);
PADDLE_ENFORCE(axis >= 0 && axis < x_dims.size(), PADDLE_ENFORCE(axis >= 0 && axis < x_dims.size(),
......
...@@ -18,11 +18,16 @@ limitations under the License. */ ...@@ -18,11 +18,16 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace operators { namespace operators {
template <typename T>
struct SubFunctor {
inline HOSTDEVICE T operator()(T a, T b) const { return a - b; }
};
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
class ElementwiseSubKernel : public framework::OpKernel<T> { class ElementwiseSubKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
ElementwiseCompute<EigenSubFunctor, DeviceContext, T>(ctx); ElementwiseComputeEx<SubFunctor<T>, DeviceContext, T>(ctx);
} }
}; };
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve. # Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
# #
#Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
#You may obtain a copy of the License at # You may obtain a copy of the License at
# #
# http://www.apache.org/licenses/LICENSE-2.0 # http://www.apache.org/licenses/LICENSE-2.0
# #
#Unless required by applicable law or agreed to in writing, software # Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS, # distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
#limitations under the License. # limitations under the License.
import unittest import unittest
import numpy as np import numpy as np
from op_test import OpTest from op_test import OpTest
...@@ -40,6 +40,16 @@ class TestElementwiseOp(OpTest): ...@@ -40,6 +40,16 @@ class TestElementwiseOp(OpTest):
['X'], 'Out', max_relative_error=0.005, no_grad_set=set('Y')) ['X'], 'Out', max_relative_error=0.005, no_grad_set=set('Y'))
class TestElementwiseAddOp_scalar(TestElementwiseOp):
def setUp(self):
self.op_type = "elementwise_add"
self.inputs = {
'X': np.random.rand(2, 3, 4).astype(np.float32),
'Y': np.random.rand(1).astype(np.float32)
}
self.outputs = {'Out': self.inputs['X'] + self.inputs['Y']}
class TestElementwiseAddOp_Vector(TestElementwiseOp): class TestElementwiseAddOp_Vector(TestElementwiseOp):
def setUp(self): def setUp(self):
self.op_type = "elementwise_add" self.op_type = "elementwise_add"
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve. # Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
# #
#Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
#You may obtain a copy of the License at # You may obtain a copy of the License at
# #
# http://www.apache.org/licenses/LICENSE-2.0 # http://www.apache.org/licenses/LICENSE-2.0
# #
#Unless required by applicable law or agreed to in writing, software # Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS, # distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
#limitations under the License. # limitations under the License.
import unittest import unittest
import numpy as np import numpy as np
from op_test import OpTest from op_test import OpTest
...@@ -45,6 +45,16 @@ class ElementwiseDivOp(OpTest): ...@@ -45,6 +45,16 @@ class ElementwiseDivOp(OpTest):
['X'], 'Out', max_relative_error=0.05, no_grad_set=set('Y')) ['X'], 'Out', max_relative_error=0.05, no_grad_set=set('Y'))
class TestElementwiseDivOp_scalar(ElementwiseDivOp):
def setUp(self):
self.op_type = "elementwise_div"
self.inputs = {
'X': np.random.uniform(0.1, 1, [2, 3, 4]).astype(np.float32),
'Y': np.random.uniform(0.1, 1, [1]).astype(np.float32)
}
self.outputs = {'Out': self.inputs['X'] / self.inputs['Y']}
class TestElementwiseDivOp_Vector(ElementwiseDivOp): class TestElementwiseDivOp_Vector(ElementwiseDivOp):
def setUp(self): def setUp(self):
self.op_type = "elementwise_div" self.op_type = "elementwise_div"
......
...@@ -43,6 +43,15 @@ class TestElementwiseOp(OpTest): ...@@ -43,6 +43,15 @@ class TestElementwiseOp(OpTest):
['X'], 'Out', max_relative_error=0.005, no_grad_set=set('Y')) ['X'], 'Out', max_relative_error=0.005, no_grad_set=set('Y'))
class TestElementwiseMaxOp_scalar(TestElementwiseOp):
def setUp(self):
self.op_type = "elementwise_max"
x = np.random.random_integers(-5, 5, [2, 3, 4]).astype("float32")
y = np.array([0.5]).astype("float32")
self.inputs = {'X': x, 'Y': y}
self.outputs = {'Out': np.maximum(self.inputs['X'], self.inputs['Y'])}
class TestElementwiseMaxOp_Vector(TestElementwiseOp): class TestElementwiseMaxOp_Vector(TestElementwiseOp):
def setUp(self): def setUp(self):
self.op_type = "elementwise_max" self.op_type = "elementwise_max"
......
...@@ -43,6 +43,15 @@ class TestElementwiseOp(OpTest): ...@@ -43,6 +43,15 @@ class TestElementwiseOp(OpTest):
['X'], 'Out', max_relative_error=0.005, no_grad_set=set('Y')) ['X'], 'Out', max_relative_error=0.005, no_grad_set=set('Y'))
class TestElementwiseMinOp_scalar(TestElementwiseOp):
def setUp(self):
self.op_type = "elementwise_min"
x = np.random.random_integers(-5, 5, [2, 3, 4]).astype("float32")
y = np.array([0.5]).astype("float32")
self.inputs = {'X': x, 'Y': y}
self.outputs = {'Out': np.minimum(self.inputs['X'], self.inputs['Y'])}
class TestElementwiseMaxOp_Vector(TestElementwiseOp): class TestElementwiseMaxOp_Vector(TestElementwiseOp):
def setUp(self): def setUp(self):
self.op_type = "elementwise_min" self.op_type = "elementwise_min"
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve. # Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
# #
#Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
#You may obtain a copy of the License at # You may obtain a copy of the License at
# #
# http://www.apache.org/licenses/LICENSE-2.0 # http://www.apache.org/licenses/LICENSE-2.0
# #
#Unless required by applicable law or agreed to in writing, software # Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS, # distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
#limitations under the License. # limitations under the License.
import unittest import unittest
import numpy as np import numpy as np
from op_test import OpTest from op_test import OpTest
...@@ -38,6 +38,16 @@ class ElementwiseMulOp(OpTest): ...@@ -38,6 +38,16 @@ class ElementwiseMulOp(OpTest):
self.check_grad(['X'], 'Out', no_grad_set=set('Y')) self.check_grad(['X'], 'Out', no_grad_set=set('Y'))
class TestElementwiseMulOp_scalar(ElementwiseMulOp):
def setUp(self):
self.op_type = "elementwise_mul"
self.inputs = {
'X': np.random.rand(2, 3, 4).astype(np.float32),
'Y': np.random.rand(1).astype(np.float32)
}
self.outputs = {'Out': self.inputs['X'] * self.inputs['Y']}
class TestElementwiseMulOp_Vector(ElementwiseMulOp): class TestElementwiseMulOp_Vector(ElementwiseMulOp):
def setUp(self): def setUp(self):
self.op_type = "elementwise_mul" self.op_type = "elementwise_mul"
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve. # Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
# #
#Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
#You may obtain a copy of the License at # You may obtain a copy of the License at
# #
# http://www.apache.org/licenses/LICENSE-2.0 # http://www.apache.org/licenses/LICENSE-2.0
# #
#Unless required by applicable law or agreed to in writing, software # Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS, # distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
#limitations under the License. # limitations under the License.
import unittest import unittest
import numpy as np import numpy as np
from op_test import OpTest from op_test import OpTest
...@@ -40,6 +40,16 @@ class TestElementwiseOp(OpTest): ...@@ -40,6 +40,16 @@ class TestElementwiseOp(OpTest):
['X'], 'Out', max_relative_error=0.005, no_grad_set=set('Y')) ['X'], 'Out', max_relative_error=0.005, no_grad_set=set('Y'))
class TestElementwiseSubOp_scalar(TestElementwiseOp):
def setUp(self):
self.op_type = "elementwise_sub"
self.inputs = {
'X': np.random.rand(2, 3, 4).astype(np.float32),
'Y': np.random.rand(1).astype(np.float32)
}
self.outputs = {'Out': self.inputs['X'] - self.inputs['Y']}
class TestElementwiseSubOp_Vector(TestElementwiseOp): class TestElementwiseSubOp_Vector(TestElementwiseOp):
def setUp(self): def setUp(self):
self.op_type = "elementwise_sub" self.op_type = "elementwise_sub"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册