Unverified commit a697e946, authored by wawltor, committed by GitHub

Update the code of the compare ops for the broadcast function
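Context for the change (not part of the commit message): before this patch, the compare ops required X's rank to be at least Y's and simply gave the output X's shape; after it, mismatched shapes are resolved with NumPy-style right-aligned broadcasting. A quick NumPy sketch of the semantics the ops now follow (NumPy is only an illustration here, not the implementation):

import numpy as np

# Right-aligned broadcasting: (1, 2, 1, 3) vs (1, 2, 3) -> (1, 2, 2, 3).
x = np.arange(1, 7).reshape((1, 2, 1, 3))
y = np.arange(0, 6).reshape((1, 2, 3))
out = x < y                  # elementwise compare with broadcasting
print(out.shape)             # (1, 2, 2, 3)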
Parent 8ec4af27
--- a/paddle/fluid/operators/controlflow/compare_op.cc
+++ b/paddle/fluid/operators/controlflow/compare_op.cc
@@ -13,8 +13,11 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "paddle/fluid/operators/controlflow/compare_op.h"
+#include <algorithm>
 #include <string>
+#include <vector>
 #include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
 
 namespace paddle {
 namespace operators {
@@ -85,14 +88,22 @@ class CompareOp : public framework::OperatorWithKernel {
 
     auto dim_x = context->GetInputDim("X");
     auto dim_y = context->GetInputDim("Y");
-    PADDLE_ENFORCE_GE(dim_x.size(), dim_y.size(),
-                      platform::errors::InvalidArgument(
-                          "The size of dim_y should not be greater than "
-                          "dim_x's, but received dim_y: %d > dim_x: %d.\n",
-                          dim_y.size(), dim_x.size()));
-
-    context->SetOutputDim("Out", context->GetInputDim("X"));
-    context->ShareLoD("X", "Out");
+    if (context->GetInputDim("X") == context->GetInputDim("Y")) {
+      context->ShareDim("X", /*->*/ "Out");
+      context->ShareLoD("X", /*->*/ "Out");
+    } else {
+      int max_dim = std::max(dim_x.size(), dim_y.size());
+      int axis = std::abs(dim_x.size() - dim_y.size());
+      std::vector<int> x_dims_array(max_dim);
+      std::vector<int> y_dims_array(max_dim);
+      std::vector<int> out_dims_array(max_dim);
+      GetBroadcastDimsArrays(dim_x, dim_y, x_dims_array.data(),
+                             y_dims_array.data(), out_dims_array.data(),
+                             max_dim, axis);
+      context->SetOutputDim("Out", framework::make_ddim(out_dims_array));
+      // to do
+      context->ShareLoD("X", /*->*/ "Out");
+    }
   }
 
   framework::OpKernelType GetExpectedKernelType(
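The core of the new InferShape logic is GetBroadcastDimsArrays, which the op now shares with the elementwise operators; note the author flags LoD sharing for broadcast outputs as "// to do". The sketch below (illustrative Python, not Paddle's C++, and it ignores Paddle's -1 placeholder for unknown dims) mirrors the shape computation under the right-alignment assumption implied by axis = abs(dim_x.size() - dim_y.size()): the shorter shape is padded on the left with 1s, each output dim is the max of the pair, and the pair must be equal or contain a 1.

def broadcast_out_shape(dim_x, dim_y):
    # Pad the shorter shape with leading 1s (right alignment), matching
    # axis = abs(len(dim_x) - len(dim_y)) in the C++ code above.
    max_dim = max(len(dim_x), len(dim_y))
    x = [1] * (max_dim - len(dim_x)) + list(dim_x)
    y = [1] * (max_dim - len(dim_y)) + list(dim_y)
    out = []
    for dx, dy in zip(x, y):
        if dx != dy and dx != 1 and dy != 1:
            raise ValueError("shapes %s and %s are not broadcastable"
                             % (dim_x, dim_y))
        out.append(max(dx, dy))
    return out

assert broadcast_out_shape((1, 2, 1, 3), (1, 2, 3)) == [1, 2, 2, 3]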
--- a/python/paddle/fluid/tests/unittests/test_compare_op.py
+++ b/python/paddle/fluid/tests/unittests/test_compare_op.py
@@ -72,25 +72,40 @@ def create_paddle_case(op_type, callback):
     class PaddleCls(unittest.TestCase):
         def setUp(self):
             self.op_type = op_type
-            self.input_x = np.array([1, 2, 3, 4])
-            self.input_y = np.array([1, 3, 2, 4])
+            self.input_x = np.array([1, 2, 3, 4]).astype(np.int64)
+            self.input_y = np.array([1, 3, 2, 4]).astype(np.int64)
             self.real_result = callback(self.input_x, self.input_y)
+            self.place = fluid.CPUPlace()
+            if core.is_compiled_with_cuda():
+                self.place = paddle.CUDAPlace(0)
 
         def test_api(self):
             with program_guard(Program(), Program()):
-                x = fluid.layers.data(name='x', shape=[4], dtype='int64')
-                y = fluid.layers.data(name='y', shape=[4], dtype='int64')
+                x = fluid.data(name='x', shape=[4], dtype='int64')
+                y = fluid.data(name='y', shape=[4], dtype='int64')
                 op = eval("paddle.%s" % (self.op_type))
                 out = op(x, y)
-                place = fluid.CPUPlace()
-                if core.is_compiled_with_cuda():
-                    place = paddle.CUDAPlace(0)
-                exe = fluid.Executor(place)
+                exe = fluid.Executor(self.place)
                 res, = exe.run(feed={"x": self.input_x,
                                      "y": self.input_y},
                                fetch_list=[out])
                 self.assertEqual((res == self.real_result).all(), True)
 
+        def test_broadcast_api_1(self):
+            with program_guard(Program(), Program()):
+                x = paddle.nn.data(name='x', shape=[1, 2, 1, 3], dtype='int32')
+                y = paddle.nn.data(name='y', shape=[1, 2, 3], dtype='int32')
+                op = eval("paddle.%s" % (self.op_type))
+                out = op(x, y)
+                exe = paddle.Executor(self.place)
+                input_x = np.arange(1, 7).reshape((1, 2, 1, 3)).astype(np.int32)
+                input_y = np.arange(0, 6).reshape((1, 2, 3)).astype(np.int32)
+                real_result = callback(input_x, input_y)
+                res, = exe.run(feed={"x": input_x,
+                                     "y": input_y},
+                               fetch_list=[out])
+                self.assertEqual((res == real_result).all(), True)
+
         def test_attr_name(self):
             with program_guard(Program(), Program()):
                 x = fluid.layers.data(name='x', shape=[4], dtype='int32')
@@ -104,6 +119,7 @@ def create_paddle_case(op_type, callback):
     globals()[cls_name] = PaddleCls
 
 
+create_paddle_case('less_than', lambda _a, _b: _a < _b)
 create_paddle_case('less_equal', lambda _a, _b: _a <= _b)
 create_paddle_case('greater_than', lambda _a, _b: _a > _b)
 create_paddle_case('greater_equal', lambda _a, _b: _a >= _b)
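For readers who want to try the new behavior outside the static-graph test harness, here is a minimal imperative-mode check. It is not part of the commit, and it assumes a Paddle 2.x build that includes this change (paddle.disable_static and paddle.to_tensor are the 2.x imperative entry points):

import numpy as np
import paddle

paddle.disable_static()  # imperative (dygraph) mode
x = paddle.to_tensor(np.arange(1, 7).reshape((1, 2, 1, 3)).astype(np.int32))
y = paddle.to_tensor(np.arange(0, 6).reshape((1, 2, 3)).astype(np.int32))
out = paddle.less_than(x, y)   # broadcasts to shape [1, 2, 2, 3]
print(out.shape)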