Unverified commit c396ee65, authored by 努力努力在努力丶, committed by GitHub

[MLU]add mlu op interface (#38241)

* [MLU]add mlu op interface

* [MLU]fix alpha of activation op
Parent 572b3e90
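The patch below makes two changes to the MLU activation kernels: it drops the `DeviceContext` template parameter (the MLU helpers take the `ExecutionContext` directly), and it replaces the fixed `alpha_` member with a runtime read of the op's `alpha` attribute, falling back to 1.0f when the attribute is absent. A minimal, self-contained sketch of that attribute-read pattern follows; `FakeContext` is a hypothetical stand-in for illustration only, not Paddle's real `framework::ExecutionContext`:

#include <iostream>
#include <string>
#include <unordered_map>

// Hypothetical stand-in for framework::ExecutionContext (illustration only).
struct FakeContext {
  std::unordered_map<std::string, float> attrs;
  bool HasAttr(const std::string& name) const { return attrs.count(name) > 0; }
  float Attr(const std::string& name) const { return attrs.at(name); }
};

int main() {
  FakeContext relu_ctx;                       // relu carries no "alpha" attribute
  FakeContext leaky_ctx{{{"alpha", 0.02f}}};  // an op that does carry one

  // The pattern used by the patch: prefer the attribute, else default to 1.0f.
  auto read_alpha = [](const FakeContext& ctx) {
    return ctx.HasAttr("alpha") ? ctx.Attr("alpha") : 1.0f;
  };

  std::cout << read_alpha(relu_ctx) << "\n";   // prints 1
  std::cout << read_alpha(leaky_ctx) << "\n";  // prints 0.02
}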
@@ -27,40 +27,37 @@ namespace operators {
 using Tensor = framework::Tensor;
 
-template <typename DeviceContext, cnnlActivationMode_t act_mode, typename T>
+template <cnnlActivationMode_t act_mode, typename T>
 class ActivationMLUKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
     auto* input = ctx.Input<Tensor>("X");
     auto* output = ctx.Output<Tensor>("Out");
-    auto& dev_ctx = ctx.template device_context<DeviceContext>();
+    float alpha = ctx.HasAttr("alpha") ? ctx.Attr<float>("alpha") : 1.0f;
 
     output->mutable_data<T>(ctx.GetPlace());
 
-    MLUCnnlActivationDesc act_desc(act_mode, alpha_);
+    MLUCnnlActivationDesc act_desc(act_mode, alpha);
     MLUCnnlTensorDesc input_desc(*input, CNNL_LAYOUT_ARRAY,
                                  ToCnnlDataType(input->type()));
     MLUCnnlTensorDesc output_desc(*output, CNNL_LAYOUT_ARRAY,
                                   ToCnnlDataType(output->type()));
 
-    MLUCnnl::Active(dev_ctx, act_desc.get(), input_desc.get(),
+    MLUCnnl::Active(ctx, act_desc.get(), input_desc.get(),
                     reinterpret_cast<const void*>(input->data<T>()),
                     output_desc.get(),
                     reinterpret_cast<void*>(output->data<T>()));
   }
-
- private:
-  float alpha_ = 1.0;
 };
 
-template <typename DeviceContext, cnnlActivationMode_t act_mode, typename T>
+template <cnnlActivationMode_t act_mode, typename T>
 class ActivationGradMLUKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
     auto* out = ctx.Input<Tensor>("Out");
     auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
     auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
-    auto& dev_ctx = ctx.template device_context<DeviceContext>();
+    float alpha = ctx.HasAttr("alpha") ? ctx.Attr<float>("alpha") : 1.0f;
 
     dx->mutable_data<T>(ctx.GetPlace());
 
@@ -70,16 +67,13 @@ class ActivationGradMLUKernel : public framework::OpKernel<T> {
                               ToCnnlDataType(out->type()));
     MLUCnnlTensorDesc dx_desc(*dx, CNNL_LAYOUT_ARRAY,
                               ToCnnlDataType(dx->type()));
-    MLUCnnlActivationDesc act_desc(act_mode, alpha_);
+    MLUCnnlActivationDesc act_desc(act_mode, alpha);
 
     MLUCnnl::ActiveGrad(
-        dev_ctx, act_desc.get(), nullptr, nullptr, nullptr, nullptr,
+        ctx, act_desc.get(), nullptr, nullptr, nullptr, nullptr,
         dout_desc.get(), reinterpret_cast<const void*>(dout->data<T>()),
         out_desc.get(), reinterpret_cast<const void*>(out->data<T>()),
         dx_desc.get(), reinterpret_cast<void*>(dx->data<T>()));
   }
-
- private:
-  float alpha_ = 1.0;
 };
 
 }  // namespace operators
@@ -88,13 +82,9 @@ class ActivationGradMLUKernel : public framework::OpKernel<T> {
 namespace ops = paddle::operators;
 
 REGISTER_OP_MLU_KERNEL(
-    relu, ops::ActivationMLUKernel<paddle::platform::MLUDeviceContext,
-                                   CNNL_ACTIVATION_RELU, float>,
-    ops::ActivationMLUKernel<paddle::platform::MLUDeviceContext,
-                             CNNL_ACTIVATION_RELU, paddle::platform::float16>);
+    relu, ops::ActivationMLUKernel<CNNL_ACTIVATION_RELU, float>,
+    ops::ActivationMLUKernel<CNNL_ACTIVATION_RELU, paddle::platform::float16>);
 
 REGISTER_OP_MLU_KERNEL(
-    relu_grad, ops::ActivationGradMLUKernel<paddle::platform::MLUDeviceContext,
-                                            CNNL_ACTIVATION_RELU, float>,
+    relu_grad, ops::ActivationGradMLUKernel<CNNL_ACTIVATION_RELU, float>,
     ops::ActivationGradMLUKernel<CNNL_ACTIVATION_RELU,
                                  paddle::platform::float16>);
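With `DeviceContext` gone from the template parameter list, registering an activation kernel needs only the CNNL activation mode and the element type. As a hedged illustration (not part of this commit), a hypothetical leaky_relu registration under the simplified interface could look like the sketch below, assuming CNNL provides a `CNNL_ACTIVATION_LEAKYRELU` mode and that the kernel reads its `alpha` attribute as shown above:

// Hypothetical follow-up registration, shown only to illustrate the
// simplified template signature introduced by this commit.
REGISTER_OP_MLU_KERNEL(
    leaky_relu, ops::ActivationMLUKernel<CNNL_ACTIVATION_LEAKYRELU, float>,
    ops::ActivationMLUKernel<CNNL_ACTIVATION_LEAKYRELU,
                             paddle::platform::float16>);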