未验证 提交 10114859 编写于 作者: F fwenguang 提交者: GitHub

[MLU] add mlu activation kernels (#41751)

上级 fc208b7e
......@@ -15,12 +15,8 @@ limitations under the Licnse. */
#include <memory>
#include <string>
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/operators/activation_op.h"
#include "paddle/fluid/operators/mlu/mlu_baseop.h"
#include "paddle/fluid/platform/device/mlu/device_context.h"
#include "paddle/phi/core/ddim.h"
namespace paddle {
namespace operators {
......@@ -38,20 +34,39 @@ class ActivationMLUKernel : public framework::OpKernel<T> {
output->mutable_data<T>(ctx.GetPlace());
MLUCnnlActivationDesc act_desc(act_mode, alpha);
MLUCnnlTensorDesc input_desc(*input, CNNL_LAYOUT_ARRAY,
ToCnnlDataType(input->dtype()));
MLUCnnlTensorDesc output_desc(*output, CNNL_LAYOUT_ARRAY,
ToCnnlDataType(output->dtype()));
MLUCnnl::Active(ctx, act_desc.get(), input_desc.get(),
reinterpret_cast<const void*>(input->data<T>()),
output_desc.get(),
reinterpret_cast<void*>(output->data<T>()));
MLUCnnlTensorDesc input_desc(*input);
MLUCnnlTensorDesc output_desc(*output);
MLUCnnl::Active(ctx, act_desc.get(), input_desc.get(), GetBasePtr(input),
output_desc.get(), GetBasePtr(output));
}
};
// For gelu, leaky_relu
template <cnnlActivationMode_t act_mode, typename T>
class ActivationGradMLUKernel : public framework::OpKernel<T> {
class ActivationGradMLUKernelV1 : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* x = ctx.Input<Tensor>("X");
auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
float alpha = ctx.HasAttr("alpha") ? ctx.Attr<float>("alpha") : 1.0f;
dx->mutable_data<T>(ctx.GetPlace());
MLUCnnlTensorDesc x_desc(*x);
MLUCnnlTensorDesc dout_desc(*dout);
MLUCnnlTensorDesc dx_desc(*dx);
MLUCnnlActivationDesc act_desc(act_mode, alpha);
MLUCnnl::ActiveGrad(ctx, act_desc.get(), nullptr, nullptr, nullptr, nullptr,
dout_desc.get(), GetBasePtr(dout), x_desc.get(),
GetBasePtr(x), dx_desc.get(), GetBasePtr(dx));
}
};
// For tanh, sigmoid
template <cnnlActivationMode_t act_mode, typename T>
class ActivationGradMLUKernelV2 : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* out = ctx.Input<Tensor>("Out");
......@@ -61,18 +76,35 @@ class ActivationGradMLUKernel : public framework::OpKernel<T> {
dx->mutable_data<T>(ctx.GetPlace());
MLUCnnlTensorDesc dout_desc(*dout, CNNL_LAYOUT_ARRAY,
ToCnnlDataType(dout->dtype()));
MLUCnnlTensorDesc out_desc(*out, CNNL_LAYOUT_ARRAY,
ToCnnlDataType(out->dtype()));
MLUCnnlTensorDesc dx_desc(*dx, CNNL_LAYOUT_ARRAY,
ToCnnlDataType(dx->dtype()));
MLUCnnlTensorDesc out_desc(*out);
MLUCnnlTensorDesc dout_desc(*dout);
MLUCnnlTensorDesc dx_desc(*dx);
MLUCnnlActivationDesc act_desc(act_mode, alpha);
MLUCnnl::ActiveGrad(
ctx, act_desc.get(), nullptr, nullptr, nullptr, nullptr,
dout_desc.get(), reinterpret_cast<const void*>(dout->data<T>()),
out_desc.get(), reinterpret_cast<const void*>(out->data<T>()),
dx_desc.get(), reinterpret_cast<void*>(dx->data<T>()));
MLUCnnl::ActiveGrad(ctx, act_desc.get(), nullptr, nullptr, out_desc.get(),
GetBasePtr(out), dout_desc.get(), GetBasePtr(dout),
nullptr, nullptr, dx_desc.get(), GetBasePtr(dx));
}
};
// For relu, relu6
template <cnnlActivationMode_t act_mode, typename T>
class ActivationGradMLUKernelV3 : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* out = ctx.Input<Tensor>("Out");
auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
float alpha = ctx.HasAttr("alpha") ? ctx.Attr<float>("alpha") : 1.0f;
dx->mutable_data<T>(ctx.GetPlace());
MLUCnnlTensorDesc out_desc(*out);
MLUCnnlTensorDesc dout_desc(*dout);
MLUCnnlTensorDesc dx_desc(*dx);
MLUCnnlActivationDesc act_desc(act_mode, alpha);
MLUCnnl::ActiveGrad(ctx, act_desc.get(), nullptr, nullptr, nullptr, nullptr,
dout_desc.get(), GetBasePtr(dout), out_desc.get(),
GetBasePtr(out), dx_desc.get(), GetBasePtr(dx));
}
};
......@@ -81,10 +113,60 @@ class ActivationGradMLUKernel : public framework::OpKernel<T> {
namespace ops = paddle::operators;
// relu
REGISTER_OP_MLU_KERNEL(
relu, ops::ActivationMLUKernel<CNNL_ACTIVATION_RELU, float>,
ops::ActivationMLUKernel<CNNL_ACTIVATION_RELU, paddle::platform::float16>);
REGISTER_OP_MLU_KERNEL(
relu_grad, ops::ActivationGradMLUKernel<CNNL_ACTIVATION_RELU, float>,
ops::ActivationGradMLUKernel<CNNL_ACTIVATION_RELU,
paddle::platform::float16>);
relu_grad, ops::ActivationGradMLUKernelV3<CNNL_ACTIVATION_RELU, float>,
ops::ActivationGradMLUKernelV3<CNNL_ACTIVATION_RELU,
paddle::platform::float16>);
// relu6
REGISTER_OP_MLU_KERNEL(
relu6, ops::ActivationMLUKernel<CNNL_ACTIVATION_RELU6, float>,
ops::ActivationMLUKernel<CNNL_ACTIVATION_RELU6, paddle::platform::float16>);
REGISTER_OP_MLU_KERNEL(
relu6_grad, ops::ActivationGradMLUKernelV3<CNNL_ACTIVATION_RELU6, float>,
ops::ActivationGradMLUKernelV3<CNNL_ACTIVATION_RELU6,
paddle::platform::float16>);
// sigmoid
REGISTER_OP_MLU_KERNEL(sigmoid,
ops::ActivationMLUKernel<CNNL_ACTIVATION_SIGMOID, float>,
ops::ActivationMLUKernel<CNNL_ACTIVATION_SIGMOID,
paddle::platform::float16>);
REGISTER_OP_MLU_KERNEL(
sigmoid_grad,
ops::ActivationGradMLUKernelV2<CNNL_ACTIVATION_SIGMOID, float>,
ops::ActivationGradMLUKernelV2<CNNL_ACTIVATION_SIGMOID,
paddle::platform::float16>);
// tanh
REGISTER_OP_MLU_KERNEL(
tanh, ops::ActivationMLUKernel<CNNL_ACTIVATION_TANH, float>,
ops::ActivationMLUKernel<CNNL_ACTIVATION_TANH, paddle::platform::float16>);
REGISTER_OP_MLU_KERNEL(
tanh_grad, ops::ActivationGradMLUKernelV2<CNNL_ACTIVATION_TANH, float>,
ops::ActivationGradMLUKernelV2<CNNL_ACTIVATION_TANH,
paddle::platform::float16>);
// gelu
REGISTER_OP_MLU_KERNEL(
gelu, ops::ActivationMLUKernel<CNNL_ACTIVATION_GELU, float>,
ops::ActivationMLUKernel<CNNL_ACTIVATION_GELU, paddle::platform::float16>);
REGISTER_OP_MLU_KERNEL(
gelu_grad, ops::ActivationGradMLUKernelV1<CNNL_ACTIVATION_GELU, float>,
ops::ActivationGradMLUKernelV1<CNNL_ACTIVATION_GELU,
paddle::platform::float16>);
// leaky_relu
REGISTER_OP_MLU_KERNEL(
leaky_relu, ops::ActivationMLUKernel<CNNL_ACTIVATION_LEAKYRELU, float>,
ops::ActivationMLUKernel<CNNL_ACTIVATION_LEAKYRELU,
paddle::platform::float16>);
REGISTER_OP_MLU_KERNEL(
leaky_relu_grad,
ops::ActivationGradMLUKernelV1<CNNL_ACTIVATION_LEAKYRELU, float>,
ops::ActivationGradMLUKernelV1<CNNL_ACTIVATION_LEAKYRELU,
paddle::platform::float16>);
......@@ -51,6 +51,8 @@ class FillConstantMLUKernel : public framework::OpKernel<T> {
}
}
}
const T *value_data = &value;
cnnlPointerMode_t pointer_mode = CNNL_POINTER_MODE_HOST;
if (ctx.HasInput("ValueTensor")) {
auto *value_tensor = ctx.Input<framework::Tensor>("ValueTensor");
PADDLE_ENFORCE_EQ(
......@@ -59,22 +61,18 @@ class FillConstantMLUKernel : public framework::OpKernel<T> {
"When use Tensor as value to set Tensor value in fill_cosntant, "
"value input(ValueTensor) size must be 1, but get %d",
value_tensor->numel()));
const T *tensor_data = value_tensor->data<T>();
framework::Tensor mlu_tensor;
value_data = value_tensor->data<T>();
auto tmp_place = value_tensor->place();
if (platform::is_mlu_place(tmp_place)) {
framework::TensorCopySync(*value_tensor, platform::CPUPlace(),
&mlu_tensor);
tensor_data = mlu_tensor.data<T>();
pointer_mode = CNNL_POINTER_MODE_DEVICE;
}
value = tensor_data[0];
}
auto shape = GetShape(ctx);
out_var->mutable_data<T>(shape, ctx.GetPlace());
MLUCnnlTensorDesc output_desc(*out_var, CNNL_LAYOUT_ARRAY,
ToCnnlDataType(out_var->dtype()));
MLUCnnl::Fill(ctx, value, output_desc.get(), GetBasePtr(out_var));
MLUCnnlTensorDesc output_desc(*out_var);
MLUCnnl::Fill(ctx, pointer_mode, value_data, output_desc.get(),
GetBasePtr(out_var));
}
};
} // namespace operators
......
......@@ -95,7 +95,8 @@ class MeanMLUGradKernel : public framework::OpKernel<T> {
MLUCnnlTensorDesc mean_var_desc(mean_var, CNNL_LAYOUT_ARRAY,
ToCnnlDataType(mean_var.dtype()));
auto value = static_cast<T>(1.0 / static_cast<float>(input_grad->numel()));
MLUCnnl::Fill(context, value, mean_var_desc.get(), GetBasePtr(&mean_var));
MLUCnnl::Fill(context, CNNL_POINTER_MODE_HOST, &value, mean_var_desc.get(),
GetBasePtr(&mean_var));
// means mul output_grad
MLUCnnlTensorDesc in_desc(*output_grad, CNNL_LAYOUT_ARRAY,
......
......@@ -136,15 +136,17 @@ class AccuracyMLUKernel : public framework::OpKernel<T> {
// [total]
total->mutable_data<int>(ctx.GetPlace());
MLUCnnlTensorDesc total_desc(*total);
MLUCnnl::Fill(ctx, num_samples, total_desc.get(), GetBasePtr(total));
MLUCnnl::Fill(ctx, CNNL_POINTER_MODE_HOST, &num_samples, total_desc.get(),
GetBasePtr(total));
// use `total` of type `float32` for calculating accuracy
Tensor total_fp32(framework::TransToPhiDataType(VT::FP32));
total_fp32.Resize(total->dims());
total_fp32.mutable_data<float>(ctx.GetPlace());
MLUCnnlTensorDesc total_fp32_desc(total_fp32);
MLUCnnl::Fill(ctx, static_cast<float>(num_samples), total_fp32_desc.get(),
GetBasePtr(&total_fp32));
float num_samples_fp32 = static_cast<float>(num_samples);
MLUCnnl::Fill(ctx, CNNL_POINTER_MODE_HOST, &num_samples_fp32,
total_fp32_desc.get(), GetBasePtr(&total_fp32));
// [accuracy]
accuracy->mutable_data<float>(ctx.GetPlace());
......
......@@ -208,8 +208,20 @@ MLUCnnlTensorDesc::~MLUCnnlTensorDesc() {
MLUCnnlActivationDesc::MLUCnnlActivationDesc(
const cnnlActivationMode_t act_mode, const float ceof) {
PADDLE_ENFORCE_MLU_SUCCESS(cnnlCreateActivationDescriptor(&active_desc_));
PADDLE_ENFORCE_MLU_SUCCESS(cnnlSetActivationDescriptor(
active_desc_, act_mode, CNNL_NOT_PROPAGATE_NAN, ceof));
PADDLE_ENFORCE_MLU_SUCCESS(cnnlSetActivationDescriptor_v4(
active_desc_, act_mode, CNNL_ACTIVATION_HIGH_PRECISION,
CNNL_NOT_PROPAGATE_NAN, ceof, 1.0f /*sliced_dim*/,
1.67326319217681884765625 /*selu_alpha*/,
1.05070102214813232421875 /*selu_lambda*/));
}
MLUCnnlActivationDesc::MLUCnnlActivationDesc(
const cnnlActivationMode_t act_mode, const float ceof,
const float sliced_dim, const float selu_alpha, const float selu_lambda) {
PADDLE_ENFORCE_MLU_SUCCESS(cnnlCreateActivationDescriptor(&active_desc_));
PADDLE_ENFORCE_MLU_SUCCESS(cnnlSetActivationDescriptor_v4(
active_desc_, act_mode, CNNL_ACTIVATION_HIGH_PRECISION,
CNNL_NOT_PROPAGATE_NAN, ceof, sliced_dim, selu_alpha, selu_lambda));
}
const cnnlActivationDescriptor_t MLUCnnlActivationDesc::get() const {
......@@ -541,12 +553,15 @@ MLUCnnlTrigonDesc::~MLUCnnlTrigonDesc() {
output_desc, output));
}
/* static */ void MLUCnnl::Fill(const ExecutionContext& ctx, float value,
/* static */ void MLUCnnl::Fill(const ExecutionContext& ctx,
const cnnlPointerMode_t pointer_mode,
const void* value_ptr,
const cnnlTensorDescriptor_t output_desc,
void* output) {
cnnlHandle_t handle = GetHandleFromCTX(ctx);
PADDLE_ENFORCE_MLU_SUCCESS(cnnlFill(handle, value, output_desc, output));
PADDLE_ENFORCE_MLU_SUCCESS(
cnnlFill_v3(handle, pointer_mode, value_ptr, output_desc, output));
}
/* static */ void MLUCnnl::QuantifyOffline(
......
......@@ -218,6 +218,9 @@ class MLUCnnlActivationDesc {
MLUCnnlActivationDesc(const MLUCnnlActivationDesc& desc) = delete;
MLUCnnlActivationDesc& operator=(const MLUCnnlActivationDesc& desc) = delete;
MLUCnnlActivationDesc(const cnnlActivationMode_t act_mode, const float ceof);
MLUCnnlActivationDesc(const cnnlActivationMode_t act_mode, const float ceof,
const float sliced_dim, const float selu_alpha,
const float selu_lambda);
const cnnlActivationDescriptor_t get() const;
~MLUCnnlActivationDesc();
......@@ -418,7 +421,8 @@ class MLUCnnl {
const cnnlTensorDescriptor_t in1_desc, const void* in1,
const cnnlTensorDescriptor_t output_desc, void* output);
static void Fill(const ExecutionContext& ctx, float value,
static void Fill(const ExecutionContext& ctx,
const cnnlPointerMode_t pointer_mode, const void* value_ptr,
const cnnlTensorDescriptor_t output_desc, void* output);
static void LRN(const ExecutionContext& ctx, const int local_size,
......
......@@ -69,7 +69,7 @@ class MLUMergedMomentumOpKernel : public framework::OpKernel<T> {
"the same Tensors."));
}
auto mu = ctx.Attr<float>("mu");
auto mu = static_cast<T>(ctx.Attr<float>("mu"));
auto lrs = ctx.MultiInput<framework::Tensor>("LearningRate");
if (lrs.size() != 1) {
PADDLE_ENFORCE_EQ(
......@@ -114,7 +114,8 @@ class MLUMergedMomentumOpKernel : public framework::OpKernel<T> {
Tensor mu_tensor = ctx.AllocateTmpTensor<T, MLUDeviceContext>({1}, dev_ctx);
MLUCnnlTensorDesc mu_tensor_desc(mu_tensor);
MLUCnnl::Fill(ctx, mu, mu_tensor_desc.get(), GetBasePtr(&mu_tensor));
MLUCnnl::Fill(ctx, CNNL_POINTER_MODE_HOST, &mu, mu_tensor_desc.get(),
GetBasePtr(&mu_tensor));
for (size_t idx = 0; idx < n; ++idx) {
RegularizationType regularization_flag =
......
......@@ -52,7 +52,8 @@ class MLUMomentumOpKernel : public framework::OpKernel<T> {
Tensor mu_tensor =
ctx.AllocateTmpTensor<T, MLUDeviceContext>({1}, dev_ctx);
MLUCnnlTensorDesc mu_tensor_desc(mu_tensor);
MLUCnnl::Fill(ctx, mu, mu_tensor_desc.get(), GetBasePtr(&mu_tensor));
MLUCnnl::Fill(ctx, CNNL_POINTER_MODE_HOST, &mu, mu_tensor_desc.get(),
GetBasePtr(&mu_tensor));
Tensor regularized_grad;
MLUCnnlTensorDesc param_desc(*param);
......
......@@ -103,8 +103,8 @@ class ReduceMeanGradMLUKernel : public framework::OpKernel<T> {
ToCnnlDataType(input_grad->dtype()));
auto value = static_cast<T>(1.0 / static_cast<float>(reduce_numel));
MLUCnnl::Fill(context, value, input_grad_desc.get(),
GetBasePtr(input_grad));
MLUCnnl::Fill(context, CNNL_POINTER_MODE_HOST, &value,
input_grad_desc.get(), GetBasePtr(input_grad));
MLUCnnlOpTensorDesc op_tensor_desc(CNNL_OP_TENSOR_MUL, ToCnnlDataType<T>(),
CNNL_NOT_PROPAGATE_NAN);
......
......@@ -27,7 +27,7 @@ class ScaleMLUKernel : public framework::OpKernel<T> {
auto* in = framework::GetLoDTensorOrSelectedRowsValueFromVar(*in_var);
// cnnl require input, scale, bias with same type. And all in device side.
auto& scale = ctx.Attr<float>("scale");
auto scale = static_cast<T>(ctx.Attr<float>("scale"));
framework::Tensor scale_tensor;
if (ctx.HasInput("ScaleTensor")) {
framework::Tensor float_scale_tensor =
......@@ -49,14 +49,16 @@ class ScaleMLUKernel : public framework::OpKernel<T> {
} else {
scale_tensor = ctx.AllocateTmpTensor<T, MLUDeviceContext>({1}, dev_ctx);
MLUCnnlTensorDesc scale_desc(scale_tensor);
MLUCnnl::Fill(ctx, scale, scale_desc.get(), GetBasePtr(&scale_tensor));
MLUCnnl::Fill(ctx, CNNL_POINTER_MODE_HOST, &scale, scale_desc.get(),
GetBasePtr(&scale_tensor));
}
auto& bias = ctx.Attr<float>("bias");
auto bias = static_cast<T>(ctx.Attr<float>("bias"));
framework::Tensor bias_tensor =
ctx.AllocateTmpTensor<T, MLUDeviceContext>({1}, dev_ctx);
MLUCnnlTensorDesc bias_desc(bias_tensor);
MLUCnnl::Fill(ctx, bias, bias_desc.get(), GetBasePtr(&bias_tensor));
MLUCnnl::Fill(ctx, CNNL_POINTER_MODE_HOST, &bias, bias_desc.get(),
GetBasePtr(&bias_tensor));
auto* out_var = ctx.OutputVar("Out");
if (in_var->IsType<phi::SelectedRows>() && in_var != out_var) {
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import numpy as np
from scipy import special
import unittest
import sys
sys.path.append("..")
from op_test import OpTest
import paddle
import paddle.fluid as fluid
paddle.enable_static()
SEED = 2021
def np_gelu(x):
y = 0.5 * x * (1 + special.erf(x / np.sqrt(2)))
return y
class TestGelu(OpTest):
def setUp(self):
self.set_mlu()
self.op_type = "gelu"
self.place = paddle.MLUPlace(0)
self.init_dtype()
np.random.seed(SEED)
x = np.random.uniform(1, 2, [11, 17]).astype(self.dtype)
out = np_gelu(x)
self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
self.attrs = {}
self.outputs = {'Out': out}
def set_mlu(self):
self.__class__.use_mlu = True
def init_dtype(self):
self.dtype = np.float32
def test_check_output(self):
self.check_output_with_place(self.place, atol=1e-3)
def test_check_grad(self):
self.check_grad_with_place(
self.place, ['X'], 'Out', max_relative_error=0.007)
class TestGeluFp16(OpTest):
def setUp(self):
self.set_mlu()
self.op_type = "gelu"
self.place = paddle.MLUPlace(0)
self.init_dtype()
np.random.seed(SEED)
x = np.random.uniform(1, 2, [3, 4]).astype(self.dtype)
out = np_gelu(x)
self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
self.attrs = {}
self.outputs = {'Out': out}
def set_mlu(self):
self.__class__.use_mlu = True
self.__class__.no_need_check_grad = True
def init_dtype(self):
self.dtype = np.float16
def test_check_output(self):
self.check_output_with_place(self.place, atol=1e-3)
class TestGeluNet(unittest.TestCase):
def _test(self, run_mlu=True):
main_prog = paddle.static.Program()
startup_prog = paddle.static.Program()
main_prog.random_seed = SEED
startup_prog.random_seed = SEED
np.random.seed(SEED)
a_np = np.random.random(size=(32, 32)).astype('float32')
b_np = np.random.random(size=(32, 32)).astype('float32')
label_np = np.random.randint(2, size=(32, 1)).astype('int64')
with paddle.static.program_guard(main_prog, startup_prog):
a = paddle.static.data(name="a", shape=[32, 32], dtype='float32')
b = paddle.static.data(name="b", shape=[32, 32], dtype='float32')
label = paddle.static.data(
name="label", shape=[32, 1], dtype='int64')
c = paddle.multiply(a, b)
fc_1 = fluid.layers.fc(input=c, size=128)
fc_1_gelu = fluid.layers.gelu(fc_1)
prediction = fluid.layers.fc(input=fc_1_gelu, size=2, act='softmax')
cost = fluid.layers.cross_entropy(input=prediction, label=label)
loss = fluid.layers.reduce_mean(cost)
sgd = fluid.optimizer.SGD(learning_rate=0.01)
sgd.minimize(loss)
if run_mlu:
place = paddle.MLUPlace(0)
else:
place = paddle.CPUPlace()
exe = paddle.static.Executor(place)
exe.run(startup_prog)
print("Start run on {}".format(place))
for epoch in range(100):
pred_res, loss_res = exe.run(
main_prog,
feed={"a": a_np,
"b": b_np,
"label": label_np},
fetch_list=[prediction, loss])
if epoch % 10 == 0:
print("Epoch {} | Prediction[0]: {}, Loss: {}".format(
epoch, pred_res[0], loss_res))
return pred_res, loss_res
def test_mlu(self):
cpu_pred, cpu_loss = self._test(False)
mlu_pred, mlu_loss = self._test(True)
self.assertTrue(np.allclose(mlu_pred, cpu_pred, atol=1e-3))
self.assertTrue(np.allclose(mlu_loss, cpu_loss, atol=1e-3))
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import numpy as np
import unittest
import sys
sys.path.append("..")
from op_test import OpTest
from test_activation_op import ref_leaky_relu
import paddle
import paddle.fluid as fluid
paddle.enable_static()
SEED = 2021
class TestLeadyRelu(OpTest):
def setUp(self):
self.set_mlu()
self.op_type = "leaky_relu"
self.place = paddle.MLUPlace(0)
self.init_dtype()
np.random.seed(SEED)
self.set_inputs()
self.set_attrs()
self.set_outputs()
def set_inputs(self):
x = np.random.uniform(-1, 1, [11, 17]).astype(self.dtype)
self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
def set_attrs(self):
self.attrs = {}
def set_outputs(self):
alpha = 0.02 if 'alpha' not in self.attrs else self.attrs['alpha']
out = ref_leaky_relu(self.inputs['X'], alpha)
self.outputs = {'Out': out}
def set_mlu(self):
self.__class__.use_mlu = True
def init_dtype(self):
self.dtype = np.float32
def test_check_output(self):
self.check_output_with_place(self.place)
def test_check_grad(self):
if self.dtype == np.float16:
self.check_grad_with_place(
self.place, ['X'], 'Out', max_relative_error=0.006)
else:
self.check_grad_with_place(self.place, ['X'], 'Out')
class TestLeadyReluFP16(TestLeadyRelu):
def init_dtype(self):
self.dtype = np.float16
class TestLeadyRelu2(TestLeadyRelu):
def set_attrs(self):
self.attrs = {'alpha': 0.5}
class TestLeadyRelu3(TestLeadyRelu):
def set_attrs(self):
self.attrs = {'alpha': -0.5}
class TestLeakyReluNet(unittest.TestCase):
def _test(self, run_mlu=True):
main_prog = paddle.static.Program()
startup_prog = paddle.static.Program()
main_prog.random_seed = SEED
startup_prog.random_seed = SEED
np.random.seed(SEED)
x_np = np.random.random(size=(32, 32)).astype('float32')
label_np = np.random.randint(2, size=(32, 1)).astype('int64')
with paddle.static.program_guard(main_prog, startup_prog):
x = paddle.static.data(name="x", shape=[32, 32], dtype='float32')
label = paddle.static.data(
name="label", shape=[32, 1], dtype='int64')
y = paddle.nn.functional.leaky_relu(x)
fc_1 = fluid.layers.fc(input=y, size=128)
prediction = fluid.layers.fc(input=fc_1, size=2, act='softmax')
cost = fluid.layers.cross_entropy(input=prediction, label=label)
loss = fluid.layers.reduce_mean(cost)
sgd = fluid.optimizer.SGD(learning_rate=0.01)
sgd.minimize(loss)
if run_mlu:
place = paddle.MLUPlace(0)
else:
place = paddle.CPUPlace()
exe = paddle.static.Executor(place)
exe.run(startup_prog)
print("Start run on {}".format(place))
for epoch in range(100):
pred_res, loss_res = exe.run(main_prog,
feed={"x": x_np,
"label": label_np},
fetch_list=[prediction, loss])
if epoch % 10 == 0:
print("Epoch {} | Prediction[0]: {}, Loss: {}".format(
epoch, pred_res[0], loss_res))
return pred_res, loss_res
def test_mlu(self):
cpu_pred, cpu_loss = self._test(False)
mlu_pred, mlu_loss = self._test(True)
self.assertTrue(np.allclose(mlu_pred, cpu_pred))
self.assertTrue(np.allclose(mlu_loss, cpu_loss))
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import paddle.fluid as fluid
import paddle
from op_test import OpTest
import numpy as np
import unittest
import sys
sys.path.append("..")
paddle.enable_static()
SEED = 2021
def ref_relu6(x, threshold=6.0):
out = np.copy(x)
out[np.abs(x - threshold) < 0.005] = threshold + 0.02
out = np.minimum(np.maximum(x, 0), threshold)
return out
class TestRelu6(OpTest):
def setUp(self):
self.set_mlu()
self.op_type = "relu6"
self.place = paddle.MLUPlace(0)
self.init_dtype()
np.random.seed(SEED)
x = np.random.uniform(-1, 10, [10, 12]).astype(self.dtype)
x[np.abs(x) < 0.005] = 0.02
out = ref_relu6(x)
self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
self.attrs = {'threshold': 6.0}
self.outputs = {'Out': out}
def set_mlu(self):
self.__class__.use_mlu = True
def test_check_output(self):
self.check_output_with_place(self.place)
def test_check_grad(self):
self.check_grad_with_place(self.place, ['X'], 'Out')
def init_dtype(self):
self.dtype = np.float32
class TestRelu6Float16(TestRelu6):
def set_mlu(self):
self.__class__.use_mlu = True
self.__class__.no_need_check_grad = True
def set_attrs(self):
self.dtype = np.float16
def test_check_output(self):
self.check_output_with_place(self.place)
class TestReluNeg(TestRelu6):
def setUp(self):
self.set_mlu()
self.op_type = "relu6"
self.place = paddle.MLUPlace(0)
self.init_dtype()
np.random.seed(SEED)
x = np.random.uniform(-10, -1, [10, 12]).astype(self.dtype)
x[np.abs(x) < 0.005] = 0.02
out = ref_relu6(x)
self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
self.attrs = {'threshold': 6.0}
self.outputs = {'Out': out}
def set_mlu(self):
self.__class__.use_mlu = True
def init_dtype(self):
self.dtype = np.float32
def test_check_output(self):
self.check_output_with_place(self.place)
class TestRelu6Net(unittest.TestCase):
def _test(self, run_mlu=True):
main_prog = paddle.static.Program()
startup_prog = paddle.static.Program()
main_prog.random_seed = SEED
startup_prog.random_seed = SEED
np.random.seed(SEED)
a_np = np.random.random(size=(32, 32)).astype('float32')
b_np = np.random.random(size=(32, 32)).astype('float32')
label_np = np.random.randint(2, size=(32, 1)).astype('int64')
with paddle.static.program_guard(main_prog, startup_prog):
a = paddle.static.data(name="a", shape=[32, 32], dtype='float32')
b = paddle.static.data(name="b", shape=[32, 32], dtype='float32')
label = paddle.static.data(
name="label", shape=[32, 1], dtype='int64')
sum = paddle.add(a, b)
z = paddle.nn.functional.relu6(sum)
fc_1 = fluid.layers.fc(input=z, size=128)
prediction = fluid.layers.fc(input=fc_1, size=2, act='softmax')
cost = fluid.layers.cross_entropy(input=prediction, label=label)
loss = fluid.layers.reduce_mean(cost)
sgd = fluid.optimizer.SGD(learning_rate=0.01)
sgd.minimize(loss)
if run_mlu:
place = paddle.MLUPlace(0)
else:
place = paddle.CPUPlace()
exe = paddle.static.Executor(place)
exe.run(startup_prog)
print("Start run on {}".format(place))
for epoch in range(100):
pred_res, loss_res = exe.run(
main_prog,
feed={"a": a_np,
"b": b_np,
"label": label_np},
fetch_list=[prediction, loss])
if epoch % 10 == 0:
print("Epoch {} | Prediction[0]: {}, Loss: {}".format(
epoch, pred_res[0], loss_res))
return pred_res, loss_res
def test_mlu(self):
cpu_pred, cpu_loss = self._test(False)
mlu_pred, mlu_loss = self._test(True)
self.assertTrue(np.allclose(mlu_pred, cpu_pred))
self.assertTrue(np.allclose(mlu_loss, cpu_loss))
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import numpy as np
import unittest
import sys
from paddle.fluid.tests.unittests.op_test import OpTest
import paddle
import paddle.fluid as fluid
paddle.enable_static()
SEED = 2021
class TestMLUSigmoid(OpTest):
def setUp(self):
self.op_type = "sigmoid"
self.set_mlu()
self.init_dtype()
np.random.seed(SEED)
x = np.random.uniform(-1, 1, [11, 17]).astype(self.dtype)
out = 1 / (1 + np.exp(-x))
self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
self.outputs = {'Out': out}
def test_check_output(self):
self.check_output_with_place(self.place)
def test_check_grad(self):
self.check_grad_with_place(
self.place, ['X'], 'Out', max_relative_error=0.01)
def set_mlu(self):
self.__class__.use_mlu = True
self.place = paddle.MLUPlace(0)
def init_dtype(self):
self.dtype = np.float32
class TestMLUSigmoidFp16(TestMLUSigmoid):
def test_check_output(self):
self.check_output_with_place(self.place, atol=1e-3)
def init_dtype(self):
self.dtype = np.float16
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import numpy as np
import unittest
import sys
sys.path.append("..")
from op_test import OpTest
import paddle
import paddle.fluid as fluid
paddle.enable_static()
SEED = 2021
class TestTanh(OpTest):
def setUp(self):
self.set_mlu()
self.op_type = "tanh"
self.place = paddle.MLUPlace(0)
self.init_dtype()
np.random.seed(SEED)
x = np.random.uniform(1, 2, [11, 17]).astype(self.dtype)
out = np.tanh(x)
self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
self.attrs = {}
self.outputs = {'Out': out}
def set_mlu(self):
self.__class__.use_mlu = True
def init_dtype(self):
self.dtype = np.float32
def test_check_output(self):
self.check_output_with_place(self.place)
def test_check_grad(self):
if self.dtype == np.float16:
self.check_grad(['X'], 'Out', max_relative_error=0.009)
else:
self.check_grad(['X'], 'Out', max_relative_error=0.009)
class TestTanhFp16(OpTest):
def setUp(self):
self.set_mlu()
self.op_type = "tanh"
self.place = paddle.MLUPlace(0)
self.init_dtype()
np.random.seed(SEED)
x = np.random.uniform(1, 2, [3, 4]).astype(self.dtype)
out = np.tanh(x)
self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
self.attrs = {}
self.outputs = {'Out': out}
def set_mlu(self):
self.__class__.use_mlu = True
self.__class__.no_need_check_grad = True
def init_dtype(self):
self.dtype = np.float16
def test_check_output(self):
self.check_output_with_place(self.place, atol=1e-3)
class TestTanhNet(unittest.TestCase):
def _test(self, run_mlu=True):
main_prog = paddle.static.Program()
startup_prog = paddle.static.Program()
main_prog.random_seed = SEED
startup_prog.random_seed = SEED
np.random.seed(SEED)
a_np = np.random.random(size=(32, 32)).astype('float32')
b_np = np.random.random(size=(32, 32)).astype('float32')
label_np = np.random.randint(2, size=(32, 1)).astype('int64')
with paddle.static.program_guard(main_prog, startup_prog):
a = paddle.static.data(name="a", shape=[32, 32], dtype='float32')
b = paddle.static.data(name="b", shape=[32, 32], dtype='float32')
label = paddle.static.data(
name="label", shape=[32, 1], dtype='int64')
c = paddle.multiply(a, b)
d = paddle.tanh(c)
fc_1 = fluid.layers.fc(input=d, size=128)
prediction = fluid.layers.fc(input=fc_1, size=2, act='softmax')
cost = fluid.layers.cross_entropy(input=prediction, label=label)
loss = fluid.layers.reduce_mean(cost)
sgd = fluid.optimizer.SGD(learning_rate=0.01)
sgd.minimize(loss)
if run_mlu:
place = paddle.MLUPlace(0)
else:
place = paddle.CPUPlace()
exe = paddle.static.Executor(place)
exe.run(startup_prog)
print("Start run on {}".format(place))
for epoch in range(100):
pred_res, loss_res = exe.run(
main_prog,
feed={"a": a_np,
"b": b_np,
"label": label_np},
fetch_list=[prediction, loss])
if epoch % 10 == 0:
print("Epoch {} | Prediction[0]: {}, Loss: {}".format(
epoch, pred_res[0], loss_res))
return pred_res, loss_res
def test_mlu(self):
cpu_pred, cpu_loss = self._test(False)
mlu_pred, mlu_loss = self._test(True)
self.assertTrue(np.allclose(mlu_pred, cpu_pred))
self.assertTrue(np.allclose(mlu_loss, cpu_loss))
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册