Unverified commit c396ee65, authored by 努力努力在努力丶, committed by GitHub

[MLU]add mlu op interface (#38241)

* [MLU]add mlu op interface

* [MLU]fix alpha of activation op
Parent 572b3e90
@@ -27,40 +27,37 @@ namespace operators {
 using Tensor = framework::Tensor;
-template <typename DeviceContext, cnnlActivationMode_t act_mode, typename T>
+template <cnnlActivationMode_t act_mode, typename T>
 class ActivationMLUKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
     auto* input = ctx.Input<Tensor>("X");
     auto* output = ctx.Output<Tensor>("Out");
-    auto& dev_ctx = ctx.template device_context<DeviceContext>();
+    float alpha = ctx.HasAttr("alpha") ? ctx.Attr<float>("alpha") : 1.0f;
     output->mutable_data<T>(ctx.GetPlace());
-    MLUCnnlActivationDesc act_desc(act_mode, alpha_);
+    MLUCnnlActivationDesc act_desc(act_mode, alpha);
     MLUCnnlTensorDesc input_desc(*input, CNNL_LAYOUT_ARRAY,
                                  ToCnnlDataType(input->type()));
     MLUCnnlTensorDesc output_desc(*output, CNNL_LAYOUT_ARRAY,
                                   ToCnnlDataType(output->type()));
-    MLUCnnl::Active(dev_ctx, act_desc.get(), input_desc.get(),
+    MLUCnnl::Active(ctx, act_desc.get(), input_desc.get(),
                     reinterpret_cast<const void*>(input->data<T>()),
                     output_desc.get(),
                     reinterpret_cast<void*>(output->data<T>()));
   }
- private:
-  float alpha_ = 1.0;
 };
-template <typename DeviceContext, cnnlActivationMode_t act_mode, typename T>
+template <cnnlActivationMode_t act_mode, typename T>
 class ActivationGradMLUKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
     auto* out = ctx.Input<Tensor>("Out");
     auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
     auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
-    auto& dev_ctx = ctx.template device_context<DeviceContext>();
+    float alpha = ctx.HasAttr("alpha") ? ctx.Attr<float>("alpha") : 1.0f;
     dx->mutable_data<T>(ctx.GetPlace());
@@ -70,16 +67,13 @@ class ActivationGradMLUKernel : public framework::OpKernel<T> {
                               ToCnnlDataType(out->type()));
     MLUCnnlTensorDesc dx_desc(*dx, CNNL_LAYOUT_ARRAY,
                               ToCnnlDataType(dx->type()));
-    MLUCnnlActivationDesc act_desc(act_mode, alpha_);
+    MLUCnnlActivationDesc act_desc(act_mode, alpha);
     MLUCnnl::ActiveGrad(
-        dev_ctx, act_desc.get(), nullptr, nullptr, nullptr, nullptr,
+        ctx, act_desc.get(), nullptr, nullptr, nullptr, nullptr,
         dout_desc.get(), reinterpret_cast<const void*>(dout->data<T>()),
         out_desc.get(), reinterpret_cast<const void*>(out->data<T>()),
         dx_desc.get(), reinterpret_cast<void*>(dx->data<T>()));
   }
- private:
-  float alpha_ = 1.0;
 };
 }  // namespace operators
@@ -88,13 +82,9 @@ class ActivationGradMLUKernel : public framework::OpKernel<T> {
 namespace ops = paddle::operators;
 REGISTER_OP_MLU_KERNEL(
-    relu, ops::ActivationMLUKernel<paddle::platform::MLUDeviceContext,
-                                   CNNL_ACTIVATION_RELU, float>,
-    ops::ActivationMLUKernel<paddle::platform::MLUDeviceContext,
-                             CNNL_ACTIVATION_RELU, paddle::platform::float16>);
+    relu, ops::ActivationMLUKernel<CNNL_ACTIVATION_RELU, float>,
+    ops::ActivationMLUKernel<CNNL_ACTIVATION_RELU, paddle::platform::float16>);
 REGISTER_OP_MLU_KERNEL(
-    relu_grad, ops::ActivationGradMLUKernel<paddle::platform::MLUDeviceContext,
-                                            CNNL_ACTIVATION_RELU, float>,
-    ops::ActivationGradMLUKernel<paddle::platform::MLUDeviceContext,
-                                 CNNL_ACTIVATION_RELU,
+    relu_grad, ops::ActivationGradMLUKernel<CNNL_ACTIVATION_RELU, float>,
+    ops::ActivationGradMLUKernel<CNNL_ACTIVATION_RELU,
                                  paddle::platform::float16>);
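
Note on the new interface: the kernels above no longer take a DeviceContext template parameter, and they read the optional "alpha" attribute (falling back to 1.0f for ops such as relu that do not define it), so registering a further activation only requires choosing a cnnlActivationMode_t value. A minimal sketch of such a registration, assuming a hypothetical leaky_relu operator and a CNNL_ACTIVATION_LEAKYRELU mode constant, neither of which is part of this commit:

// Illustrative sketch only (not part of this commit): reuses the templated
// kernel above, via the ops namespace alias, for a hypothetical leaky_relu
// op; its "alpha" attribute would be read by ctx.HasAttr("alpha") in
// ActivationMLUKernel::Compute.
REGISTER_OP_MLU_KERNEL(
    leaky_relu, ops::ActivationMLUKernel<CNNL_ACTIVATION_LEAKYRELU, float>,
    ops::ActivationMLUKernel<CNNL_ACTIVATION_LEAKYRELU,
                             paddle::platform::float16>);

Ops without an alpha attribute, such as relu registered above, simply fall back to the 1.0f default through the HasAttr check.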