未验证 提交 e5478ab5 编写于 作者: H Hongyu Liu 提交者: GitHub

Merge pull request #16346 from phlrain/add_floordiv_and_mod

add elementwise floordiv, mod
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/elementwise/elementwise_floordiv_op.h"
#include <string>
#include "paddle/fluid/operators/elementwise/elementwise_op.h"
namespace paddle {
namespace operators {
class ElementwiseFloorDivOpMaker : public ElementwiseOpMaker {
protected:
std::string GetName() const override { return "FloorDiv"; }
std::string GetEquation() const override { return "Out = X // Y"; }
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(elementwise_floordiv, ops::ElementwiseOp,
ops::ElementwiseFloorDivOpMaker);
REGISTER_OP_CPU_KERNEL(
elementwise_floordiv,
ops::ElementwiseFloorDivKernel<paddle::platform::CPUDeviceContext, int>,
ops::ElementwiseFloorDivKernel<paddle::platform::CPUDeviceContext,
int64_t>);
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/elementwise/elementwise_floordiv_op.h"
#include "paddle/fluid/platform/float16.h"
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(
elementwise_floordiv,
ops::ElementwiseFloorDivKernel<plat::CUDADeviceContext, int>,
ops::ElementwiseFloorDivKernel<plat::CUDADeviceContext, int64_t>);
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/operators/elementwise/elementwise_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
#include "paddle/fluid/operators/math/blas.h"
namespace paddle {
namespace operators {
template <typename T>
struct FloorDivFunctor {
inline HOSTDEVICE T operator()(T a, T b) const { return a / b; }
};
template <typename DeviceContext, typename T>
void elementwise_floor_div(const framework::ExecutionContext &ctx,
const framework::Tensor *x,
const framework::Tensor *y, framework::Tensor *z) {
int axis = ctx.Attr<int>("axis");
ElementwiseComputeEx<FloorDivFunctor<T>, DeviceContext, T>(
ctx, x, y, axis, FloorDivFunctor<T>(), z);
}
template <typename DeviceContext, typename T>
class ElementwiseFloorDivKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
auto *x = ctx.Input<framework::LoDTensor>("X");
auto *y = ctx.Input<framework::LoDTensor>("Y");
auto *z = ctx.Output<framework::LoDTensor>("Out");
z->mutable_data<T>(ctx.GetPlace());
// dtype of x and y is int64 or int32
elementwise_floor_div<DeviceContext, T>(ctx, x, y, z);
}
};
} // namespace operators
} // namespace paddle
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/elementwise/elementwise_mod_op.h"
#include <string>
#include "paddle/fluid/operators/elementwise/elementwise_op.h"
namespace paddle {
namespace operators {
class ElementwiseModOpMaker : public ElementwiseOpMaker {
protected:
std::string GetName() const override { return "Mod"; }
std::string GetEquation() const override { return "Out = X % Y"; }
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(elementwise_mod, ops::ElementwiseOp,
ops::ElementwiseModOpMaker);
REGISTER_OP_CPU_KERNEL(
elementwise_mod,
ops::ElementwiseModKernel<paddle::platform::CPUDeviceContext, int>,
ops::ElementwiseModKernel<paddle::platform::CPUDeviceContext, int64_t>);
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/elementwise/elementwise_mod_op.h"
#include "paddle/fluid/platform/float16.h"
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(
elementwise_mod, ops::ElementwiseModKernel<plat::CUDADeviceContext, int>,
ops::ElementwiseModKernel<plat::CUDADeviceContext, int64_t>);
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/operators/elementwise/elementwise_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
#include "paddle/fluid/operators/math/blas.h"
namespace paddle {
namespace operators {
template <typename T>
struct ModFunctor {
inline HOSTDEVICE T operator()(T a, T b) const { return a % b; }
};
template <typename DeviceContext, typename T>
void elementwise_mod(const framework::ExecutionContext &ctx,
const framework::Tensor *x, const framework::Tensor *y,
framework::Tensor *z) {
int axis = ctx.Attr<int>("axis");
ElementwiseComputeEx<ModFunctor<T>, DeviceContext, T>(ctx, x, y, axis,
ModFunctor<T>(), z);
}
template <typename DeviceContext, typename T>
class ElementwiseModKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
auto *x = ctx.Input<framework::LoDTensor>("X");
auto *y = ctx.Input<framework::LoDTensor>("Y");
auto *z = ctx.Output<framework::LoDTensor>("Out");
z->mutable_data<T>(ctx.GetPlace());
// dtype of x and y is int64 or int32
elementwise_mod<DeviceContext, T>(ctx, x, y, z);
}
};
} // namespace operators
} // namespace paddle
......@@ -174,6 +174,8 @@ def monkey_patch_variable():
("__rtruediv__", "elementwise_div", True),
("__pow__", "elementwise_pow", False),
("__rpow__", "elementwise_pow", True),
("__floordiv__", "elementwise_floordiv", False),
("__mod__", "elementwise_mod", False),
# for logical compare
("__eq__", "equal", False),
("__ne__", "not_equal", False),
......
......@@ -9231,9 +9231,24 @@ def elementwise_pow(x, y, axis=-1, act=None, name=None):
return _elementwise_op(LayerHelper('elementwise_pow', **locals()))
def elementwise_mod(x, y, axis=-1, act=None, name=None):
return _elementwise_op(LayerHelper('elementwise_mod', **locals()))
def elementwise_floordiv(x, y, axis=-1, act=None, name=None):
return _elementwise_op(LayerHelper('elementwise_floordiv', **locals()))
for func in [
elementwise_add, elementwise_div, elementwise_sub, elementwise_mul,
elementwise_max, elementwise_min, elementwise_pow
elementwise_add,
elementwise_div,
elementwise_sub,
elementwise_mul,
elementwise_max,
elementwise_min,
elementwise_pow,
elementwise_mod,
elementwise_floordiv,
]:
op_proto = OpProtoHolder.instance().get_op_proto(func.__name__)
func.__doc__ = _generate_doc_string_(
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import paddle.fluid.core as core
from op_test import OpTest
import random
class TestElementwiseModOp(OpTest):
def init_kernel_type(self):
self.use_mkldnn = False
def setUp(self):
self.op_type = "elementwise_floordiv"
self.dtype = np.int32
self.axis = -1
self.init_dtype()
self.init_input_output()
self.init_kernel_type()
self.init_axis()
self.inputs = {
'X': OpTest.np_dtype_to_fluid_dtype(self.x),
'Y': OpTest.np_dtype_to_fluid_dtype(self.y)
}
self.attrs = {'axis': self.axis, 'use_mkldnn': self.use_mkldnn}
self.outputs = {'Out': self.out}
def test_check_output(self):
self.check_output()
def init_input_output(self):
self.x = np.random.uniform(0, 10000, [10, 10]).astype(self.dtype)
self.y = np.random.uniform(0, 1000, [10, 10]).astype(self.dtype)
self.out = np.floor_divide(self.x, self.y)
def init_dtype(self):
pass
def init_axis(self):
pass
class TestElementwiseModOp_scalar(TestElementwiseModOp):
def init_input_output(self):
scale_x = random.randint(0, 100000000)
scale_y = random.randint(1, 100000000)
self.x = (np.random.rand(2, 3, 4) * scale_x).astype(self.dtype)
self.y = (np.random.rand(1) * scale_y + 1).astype(self.dtype)
self.out = np.floor_divide(self.x, self.y)
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import paddle.fluid.core as core
from op_test import OpTest
import random
class TestElementwiseModOp(OpTest):
def init_kernel_type(self):
self.use_mkldnn = False
def setUp(self):
self.op_type = "elementwise_mod"
self.dtype = np.int32
self.axis = -1
self.init_dtype()
self.init_input_output()
self.init_kernel_type()
self.init_axis()
self.inputs = {
'X': OpTest.np_dtype_to_fluid_dtype(self.x),
'Y': OpTest.np_dtype_to_fluid_dtype(self.y)
}
self.attrs = {'axis': self.axis, 'use_mkldnn': self.use_mkldnn}
self.outputs = {'Out': self.out}
def test_check_output(self):
self.check_output()
def init_input_output(self):
self.x = np.random.uniform(0, 10000, [10, 10]).astype(self.dtype)
self.y = np.random.uniform(0, 1000, [10, 10]).astype(self.dtype)
self.out = np.mod(self.x, self.y)
def init_dtype(self):
pass
def init_axis(self):
pass
class TestElementwiseModOp_scalar(TestElementwiseModOp):
def init_input_output(self):
scale_x = random.randint(0, 100000000)
scale_y = random.randint(1, 100000000)
self.x = (np.random.rand(2, 3, 4) * scale_x).astype(self.dtype)
self.y = (np.random.rand(1) * scale_y + 1).astype(self.dtype)
self.out = np.mod(self.x, self.y)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册