Unverified commit bd06a828 authored by zhaoying9105, committed by GitHub

[MLU]: add hard_sigmoid,hard_sigmoid_grad,hard_swish,hard_swish_grad kernel (#44044)

Parent 8f8a6848
@@ -256,6 +256,149 @@ class ExpGradMLUKernel : public framework::OpKernel<T> {
}
};
template <typename T>
class HardSwishMLUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* input = ctx.Input<Tensor>("X");
auto* output = ctx.Output<Tensor>("Out");
output->mutable_data<T>(ctx.GetPlace());
float threshold = ctx.Attr<float>("threshold");
float scale = ctx.Attr<float>("scale");
float offset = ctx.Attr<float>("offset");
PADDLE_ENFORCE_EQ(threshold,
6.0f,
platform::errors::External(
"Not support threshold [%f] in MLU", threshold));
PADDLE_ENFORCE_EQ(
scale,
6.0f,
platform::errors::External("Not support scale [%f] in MLU", scale));
PADDLE_ENFORCE_EQ(
offset,
3.0f,
platform::errors::External("Not support offset [%f] in MLU", offset));
MLUCnnlActivationDesc act_desc(CNNL_ACTIVATION_HARDSWISH,
                                   1.0f /*coef, unused*/);
MLUCnnlTensorDesc input_desc(*input);
MLUCnnlTensorDesc output_desc(*output);
MLUCnnl::Active(ctx,
act_desc.get(),
input_desc.get(),
GetBasePtr(input),
output_desc.get(),
GetBasePtr(output));
}
};
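For reference, the forward computation this kernel offloads to CNNL reduces to the NumPy sketch below (illustrative only; np_hard_swish is not part of this patch, and its defaults mirror the attribute values enforced above):

import numpy as np

def np_hard_swish(x, threshold=6.0, scale=6.0, offset=3.0):
    # hard_swish(x) = x * min(max(x + offset, 0), threshold) / scale
    return x * np.minimum(np.maximum(x + offset, 0.0), threshold) / scale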
template <typename T>
class HardSwishGradMLUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
float threshold = ctx.Attr<float>("threshold");
float scale = ctx.Attr<float>("scale");
float offset = ctx.Attr<float>("offset");
PADDLE_ENFORCE_EQ(threshold,
6.0f,
platform::errors::External(
"Not support threshold [%f] in MLU", threshold));
PADDLE_ENFORCE_EQ(
scale,
6.0f,
platform::errors::External("Not support scale [%f] in MLU", scale));
PADDLE_ENFORCE_EQ(
offset,
3.0f,
platform::errors::External("Not support offset [%f] in MLU", offset));
auto* out = ctx.Input<Tensor>("X");
auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
dx->mutable_data<T>(ctx.GetPlace());
MLUCnnlTensorDesc out_desc(*out);
MLUCnnlTensorDesc dout_desc(*dout);
MLUCnnlTensorDesc dx_desc(*dx);
MLUCnnlActivationDesc act_desc(CNNL_ACTIVATION_HARDSWISH,
                                   1.0f /*coef, unused*/);
MLUCnnl::ActiveGrad(ctx,
act_desc.get(),
nullptr,
nullptr,
nullptr,
nullptr,
dout_desc.get(),
GetBasePtr(dout),
out_desc.get(),
GetBasePtr(out),
dx_desc.get(),
GetBasePtr(dx));
}
};
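The backward pass is the piecewise derivative of the expression above; a minimal NumPy sketch follows (np_hard_swish_grad is illustrative and mirrors ref_hard_swish_grad in the test added below):

import numpy as np

def np_hard_swish_grad(dout, x, threshold=6.0, scale=6.0, offset=3.0):
    # derivative of x * min(max(x + offset, 0), threshold) / scale:
    #   x + offset <= 0            -> 0
    #   0 < x + offset < threshold -> (2 * x + offset) / scale
    #   x + offset >= threshold    -> 1, which equals threshold / scale under
    #                                 the enforced threshold == scale == 6
    inside = ((x + offset) < threshold).astype(x.dtype)
    positive = ((x + offset) > 0).astype(x.dtype)
    return dout * (positive * (2 * x + offset) * inside / scale + 1.0 - inside)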
template <typename T>
class HardSigmoidMLUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* input = ctx.Input<Tensor>("X");
auto* output = ctx.Output<Tensor>("Out");
float slope = ctx.Attr<float>("slope");
float offset = ctx.Attr<float>("offset");
output->mutable_data<T>(ctx.GetPlace());
MLUCnnlActivationDesc act_desc(CNNL_ACTIVATION_HARDSIGMOID,
                                   1.0f /*coef, unused*/,
                                   1.0f /*sliced_dim, unused*/,
slope,
offset);
MLUCnnlTensorDesc input_desc(*input);
MLUCnnlTensorDesc output_desc(*output);
MLUCnnl::Active(ctx,
act_desc.get(),
input_desc.get(),
GetBasePtr(input),
output_desc.get(),
GetBasePtr(output));
}
};
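For reference, hard_sigmoid is the piecewise-linear approximation clip(slope * x + offset, 0, 1); a minimal NumPy sketch (np_hard_sigmoid is illustrative and mirrors ref_hardsigmoid in the test added below):

import numpy as np

def np_hard_sigmoid(x, slope=0.166666666666667, offset=0.5):
    # hard_sigmoid(x) = clip(slope * x + offset, 0, 1)
    return np.clip(slope * x + offset, 0.0, 1.0).astype(x.dtype)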
template <typename T>
class HardSigmoidGradMLUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto* out = ctx.Input<Tensor>("Out");
auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
float slope = ctx.Attr<float>("slope");
float offset = ctx.Attr<float>("offset");
dx->mutable_data<T>(ctx.GetPlace());
MLUCnnlActivationDesc act_desc(CNNL_ACTIVATION_HARDSIGMOID,
                                   1.0f /*coef, unused*/,
                                   1.0f /*sliced_dim, unused*/,
slope,
offset);
MLUCnnlTensorDesc out_desc(*out);
MLUCnnlTensorDesc dout_desc(*dout);
MLUCnnlTensorDesc dx_desc(*dx);
MLUCnnl::ActiveGrad(ctx,
act_desc.get(),
nullptr,
nullptr,
nullptr,
nullptr,
dout_desc.get(),
GetBasePtr(dout),
out_desc.get(),
GetBasePtr(out),
dx_desc.get(),
GetBasePtr(dx));
}
};
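The gradient kernel is fed the forward output Out; assuming the usual hard_sigmoid derivative computed from that output, it amounts to the sketch below (np_hard_sigmoid_grad is illustrative, not part of this patch):

import numpy as np

def np_hard_sigmoid_grad(dout, out, slope=0.166666666666667):
    # slope inside the linear region (0 < out < 1), zero where the output
    # saturates at 0 or 1
    return dout * ((out > 0) & (out < 1)).astype(out.dtype) * slope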
} // namespace operators
} // namespace paddle
@@ -359,3 +502,20 @@ REGISTER_OP_MLU_KERNEL(exp,
REGISTER_OP_MLU_KERNEL(exp_grad,
ops::ExpGradMLUKernel<float>,
ops::ExpGradMLUKernel<paddle::platform::float16>);
REGISTER_OP_MLU_KERNEL(hard_swish,
ops::HardSwishMLUKernel<float>,
ops::HardSwishMLUKernel<paddle::platform::float16>);
REGISTER_OP_MLU_KERNEL(hard_swish_grad,
ops::HardSwishGradMLUKernel<float>,
ops::HardSwishGradMLUKernel<paddle::platform::float16>);
REGISTER_OP_MLU_KERNEL(hard_sigmoid,
ops::HardSigmoidMLUKernel<float>,
ops::HardSigmoidMLUKernel<paddle::platform::float16>);
REGISTER_OP_MLU_KERNEL(
hard_sigmoid_grad,
ops::HardSigmoidGradMLUKernel<float>,
ops::HardSigmoidGradMLUKernel<paddle::platform::float16>);
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import numpy as np
import unittest
import sys
sys.path.append("..")
from op_test import OpTest
import paddle
import paddle.fluid as fluid
import paddle.nn.functional as F
paddle.enable_static()
SEED = 2021
def ref_hardsigmoid(x, slope=0.166666666666667, offset=0.5):
return np.maximum(np.minimum(x * slope + offset, 1.), 0.).astype(x.dtype)
class TestMLUHardSigmoid(OpTest):
def setUp(self):
paddle.enable_static()
self.op_type = "hard_sigmoid"
self.set_mlu()
self.init_dtype()
self.set_attrs()
x = np.random.uniform(-5, 5, [10, 12]).astype(self.dtype)
lower_threshold = -self.offset / self.slope
upper_threshold = (1. - self.offset) / self.slope
        # Same reason as TestAbs: nudge inputs away from the kink points so the
        # numeric gradient check stays well defined.
delta = 0.005
x[np.abs(x - lower_threshold) < delta] = lower_threshold - 0.02
x[np.abs(x - upper_threshold) < delta] = upper_threshold - 0.02
out = ref_hardsigmoid(x, self.slope, self.offset)
self.attrs = {'slope': self.slope, 'offset': self.offset}
self.inputs = {'X': x}
self.outputs = {'Out': out}
def test_check_output(self):
self.check_output_with_place(self.place)
def test_check_grad(self):
self.check_grad_with_place(self.place, ['X'], 'Out')
def set_mlu(self):
self.__class__.use_mlu = True
self.place = paddle.MLUPlace(0)
def init_dtype(self):
self.dtype = np.float32
def set_attrs(self):
self.slope = 0.166666666666667
self.offset = 0.5
class TestMLUHardSigmoid2(TestMLUHardSigmoid):
def set_attrs(self):
self.slope = 0.2
self.offset = 0.5
class TestMLUHardSigmoid3(TestMLUHardSigmoid):
def set_attrs(self):
self.slope = 0.2
self.offset = 0.4
class TestMLUHardSigmoidFp16(unittest.TestCase):
def setUp(self):
paddle.disable_static()
self.place = paddle.MLUPlace(0)
self.dtype = np.float32
# float32
self.float32_x = np.random.uniform(-5, 5, [10, 12]).astype(np.float32)
paddle.set_device('cpu')
data = paddle.to_tensor(self.float32_x, stop_gradient=True)
self.float32_y = F.hardsigmoid(data)
# float16
self.float16_x = self.float32_x.astype(np.float16)
self.float16_y = ref_hardsigmoid(self.float16_x)
def test_check_output_and_grad_mlu(self):
# mlu float16
paddle.set_device('mlu')
data = paddle.to_tensor(self.float16_x, stop_gradient=True)
mlu_float16_y = F.hardsigmoid(data)
cpu_diff_1 = np.divide(
np.sum(np.abs(self.float32_y.numpy() - self.float16_y)),
np.sum(np.abs(self.float32_y.numpy())))
mlu_diff_1 = np.divide(
np.sum(np.abs(self.float32_y.numpy() - mlu_float16_y.numpy())),
np.sum(np.abs(self.float32_y.numpy())))
cpu_diff_2 = np.divide(
np.sum(np.square(self.float32_y.numpy() - self.float16_y)),
np.sum(np.square(self.float32_y.numpy())))
mlu_diff_2 = np.divide(
np.sum(np.square(self.float32_y.numpy() - mlu_float16_y.numpy())),
np.sum(np.square(self.float32_y.numpy())))
assert mlu_diff_1 <= cpu_diff_1
assert mlu_diff_2 <= cpu_diff_2
class TestHardsigmoidAPI(unittest.TestCase):
# test paddle.nn.Hardsigmoid, paddle.nn.functional.hardsigmoid
def setUp(self):
self.x_np = np.random.uniform(-1, 1, [10, 12]).astype(np.float32)
self.place = paddle.MLUPlace(0)
def test_static_api(self):
with paddle.static.program_guard(paddle.static.Program()):
x = paddle.static.data('X', self.x_np.shape, self.x_np.dtype)
out1 = F.hardsigmoid(x)
m = paddle.nn.Hardsigmoid()
out2 = m(x)
exe = paddle.static.Executor(self.place)
res = exe.run(feed={'X': self.x_np}, fetch_list=[out1, out2])
out_ref = ref_hardsigmoid(self.x_np)
for r in res:
self.assertTrue(np.allclose(out_ref, r))
def test_dygraph_api(self):
paddle.disable_static(self.place)
x = paddle.to_tensor(self.x_np)
out1 = F.hardsigmoid(x)
m = paddle.nn.Hardsigmoid()
out2 = m(x)
out_ref = ref_hardsigmoid(self.x_np)
for r in [out1, out2]:
self.assertTrue(np.allclose(out_ref, r.numpy()))
paddle.enable_static()
def test_fluid_api(self):
with fluid.program_guard(fluid.Program()):
x = fluid.data('X', self.x_np.shape, self.x_np.dtype)
out = fluid.layers.hard_sigmoid(x)
exe = fluid.Executor(self.place)
res = exe.run(feed={'X': self.x_np}, fetch_list=[out])
out_ref = ref_hardsigmoid(self.x_np, 0.2, 0.5)
self.assertTrue(np.allclose(out_ref, res[0]))
paddle.disable_static(self.place)
x = paddle.to_tensor(self.x_np)
out = paddle.fluid.layers.hard_sigmoid(x)
self.assertTrue(np.allclose(out_ref, out.numpy()))
paddle.enable_static()
def test_errors(self):
with paddle.static.program_guard(paddle.static.Program()):
# The input type must be Variable.
self.assertRaises(TypeError, F.hardsigmoid, 1)
# The input dtype must be float16, float32, float64.
x_int32 = paddle.fluid.data(name='x_int32',
shape=[12, 10],
dtype='int32')
self.assertRaises(TypeError, F.hardsigmoid, x_int32)
# support the input dtype is float16
x_fp16 = paddle.fluid.data(name='x_fp16',
shape=[12, 10],
dtype='float16')
F.hardsigmoid(x_fp16)
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import numpy as np
import unittest
import sys
sys.path.append("..")
from op_test import OpTest
import paddle
import paddle.fluid as fluid
import paddle.nn.functional as F
paddle.enable_static()
SEED = 2020
def scalarToType(val, data_type):
converted_val = np.array([val]).astype(data_type)[0]
print("converted_val type: ", type(converted_val))
return converted_val
def ref_hard_swish_grad(x, threshold, scale, offset, data_type):
threshold = scalarToType(threshold, data_type)
scale = scalarToType(scale, data_type)
offset = scalarToType(offset, data_type)
dout = np.full_like(x, fill_value=1. / x.size)
tmp = ((x + offset) < threshold).astype(x.dtype)
dx = dout * (((x + offset) > 0).astype(x.dtype) *
(2 * x + offset) * tmp / scale + 1.0 - tmp)
return dx
class TestHardSwishMLU(OpTest):
def setUp(self):
paddle.enable_static()
self.op_type = "hard_swish"
self.place = paddle.MLUPlace(0)
self.init_dtype()
x = np.random.uniform(-2, 2, [10, 12]).astype(self.dtype)
threshold = 6.0
scale = 6.0
offset = 3.0
x[np.abs(x + offset) < 0.005] = 0.02
x[np.abs(x - threshold + offset) < 0.005] = threshold - offset + 0.02
out = (
x *
(np.minimum(np.maximum(x + offset, 0.), threshold) / scale)).astype(
self.dtype)
self.x_grad = ref_hard_swish_grad(x, threshold, scale, offset,
self.dtype)
self.set_mlu()
self.inputs = {'X': x}
self.attrs = {'threshold': threshold, 'scale': scale, 'offset': offset}
self.outputs = {'Out': out}
def set_mlu(self):
self.__class__.use_mlu = True
def init_dtype(self):
self.dtype = np.float32
def test_check_output(self):
self.check_output_with_place(self.place)
def test_check_grad(self):
self.check_grad_with_place(self.place, ['X'], 'Out')
class TestHardSwishMLUWithCPUFloat16(unittest.TestCase):
def setUp(self):
paddle.disable_static()
self.place = paddle.MLUPlace(0)
self.dtype = np.float32
# float32
self.float32_x = np.random.uniform(-6, 10, [8, 15]).astype(np.float32)
paddle.set_device('cpu')
data = paddle.to_tensor(self.float32_x, stop_gradient=False)
self.float32_y = F.hardswish(data)
self.float32_y.sum().backward()
self.float32_grad = data.grad
# float16
self.float16_x = self.float32_x.astype(np.float16)
threshold = 6.0
scale = 6.0
offset = 3.0
threshold = scalarToType(threshold, np.float16)
scale = scalarToType(scale, np.float16)
offset = scalarToType(offset, np.float16)
self.float16_y = (self.float16_x * (np.minimum(
np.maximum(self.float16_x + offset, scalarToType(0., np.float16)),
threshold) / scale)).astype(np.float16)
self.float16_grad = ref_hard_swish_grad(self.float16_x, threshold,
scale, offset, np.float16)
def test_check_output_and_grad_mlu(self):
# mlu float16
paddle.set_device('mlu')
data = paddle.to_tensor(self.float16_x, stop_gradient=False)
mlu_float16_y = F.hardswish(data)
mlu_float16_y.sum().backward()
mlu_float16_grad = data.grad
cpu_diff_1 = np.divide(
np.sum(np.abs(self.float32_y.numpy() - self.float16_y)),
np.sum(np.abs(self.float32_y.numpy())))
mlu_diff_1 = np.divide(
np.sum(np.abs(self.float32_y.numpy() - mlu_float16_y.numpy())),
np.sum(np.abs(self.float32_y.numpy())))
cpu_diff_2 = np.divide(
np.sum(np.square(self.float32_y.numpy() - self.float16_y)),
np.sum(np.square(self.float32_y.numpy())))
mlu_diff_2 = np.divide(
np.sum(np.square(self.float32_y.numpy() - mlu_float16_y.numpy())),
np.sum(np.square(self.float32_y.numpy())))
assert mlu_diff_1 <= cpu_diff_1
assert mlu_diff_2 <= cpu_diff_2
cpu_diff_1 = np.divide(
np.sum(np.abs(self.float32_grad.numpy() - self.float16_grad)),
np.sum(np.abs(self.float32_grad.numpy())))
mlu_diff_1 = np.divide(
np.sum(np.abs(self.float32_grad.numpy() -
mlu_float16_grad.numpy())),
np.sum(np.abs(self.float32_grad.numpy())))
cpu_diff_2 = np.divide(
np.sum(np.square(self.float32_grad.numpy() - self.float16_grad)),
np.sum(np.square(self.float32_grad.numpy())))
mlu_diff_2 = np.divide(
np.sum(
np.square(self.float32_grad.numpy() -
mlu_float16_grad.numpy())),
np.sum(np.square(self.float32_grad.numpy())))
assert mlu_diff_1 <= cpu_diff_1
assert mlu_diff_2 <= cpu_diff_2
if __name__ == '__main__':
unittest.main()