From 9f06069d4663ea7b7cb989d1d93d02d0550e9282 Mon Sep 17 00:00:00 2001 From: qipengh Date: Mon, 18 Apr 2022 10:10:29 +0800 Subject: [PATCH] [MLU]add op: reduce_sum, elementwise_sub (#41697) * [MLU]add op: reduce_sum, elementwise_sub * [MLU]del unrelated code --- .../elementwise/elementwise_add_op_mlu.cc | 69 +----- .../operators/elementwise/elementwise_mlu.h | 207 +++++++++++++++++ .../elementwise/elementwise_mul_op_mlu.cc | 47 +--- .../elementwise/elementwise_sub_op_mlu.cc | 112 ++++++++++ paddle/fluid/operators/mlu/mlu_baseop.h | 16 ++ .../reduce_ops/reduce_mean_op_mlu.cc | 41 +--- .../operators/reduce_ops/reduce_op_mlu.h | 73 ++++++ .../operators/reduce_ops/reduce_sum_op_mlu.cc | 78 +++++++ .../mlu/test_elementwise_sub_op_mlu.py | 208 ++++++++++++++++++ .../unittests/mlu/test_reduce_sum_op_mlu.py | 149 +++++++++++++ 10 files changed, 853 insertions(+), 147 deletions(-) create mode 100644 paddle/fluid/operators/elementwise/elementwise_mlu.h create mode 100644 paddle/fluid/operators/elementwise/elementwise_sub_op_mlu.cc create mode 100644 paddle/fluid/operators/reduce_ops/reduce_op_mlu.h create mode 100644 paddle/fluid/operators/reduce_ops/reduce_sum_op_mlu.cc create mode 100644 python/paddle/fluid/tests/unittests/mlu/test_elementwise_sub_op_mlu.py create mode 100644 python/paddle/fluid/tests/unittests/mlu/test_reduce_sum_op_mlu.py diff --git a/paddle/fluid/operators/elementwise/elementwise_add_op_mlu.cc b/paddle/fluid/operators/elementwise/elementwise_add_op_mlu.cc index 47a549dfcd..98d559df23 100644 --- a/paddle/fluid/operators/elementwise/elementwise_add_op_mlu.cc +++ b/paddle/fluid/operators/elementwise/elementwise_add_op_mlu.cc @@ -12,8 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/elementwise/elementwise_add_op.h" -#include "paddle/fluid/operators/mlu/mlu_baseop.h" +#include "paddle/fluid/operators/elementwise/elementwise_mlu.h" namespace paddle { namespace operators { @@ -23,35 +22,7 @@ template class ElementwiseAddMLUKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { - auto* x = ctx.Input("X"); - auto* y = ctx.Input("Y"); - auto* out = ctx.Output("Out"); - out->mutable_data(ctx.GetPlace()); - - int axis = ctx.Attr("axis"); - const auto& x_dims = x->dims(); - const auto& y_dims = y->dims(); - axis = (axis < 0 ? (std::abs(x_dims.size() - y_dims.size()) + axis + 1) - : axis); - int max_dim = std::max(x_dims.size(), y_dims.size()); - std::vector x_dims_array(max_dim); - std::vector y_dims_array(max_dim); - std::vector out_dims_array(max_dim); - GetBroadcastDimsArrays(x_dims, y_dims, x_dims_array.data(), - y_dims_array.data(), out_dims_array.data(), max_dim, - axis); - - MLUCnnlTensorDesc x_desc(max_dim, x_dims_array.data(), - ToCnnlDataType(x->type())); - MLUCnnlTensorDesc y_desc(max_dim, y_dims_array.data(), - ToCnnlDataType(y->type())); - MLUCnnlTensorDesc out_desc(*out); - MLUCnnlOpTensorDesc op_tensor_desc(CNNL_OP_TENSOR_ADD, ToCnnlDataType(), - CNNL_NOT_PROPAGATE_NAN); - - MLUCnnl::OpTensor(ctx, op_tensor_desc.get(), x_desc.get(), GetBasePtr(x), - y_desc.get(), GetBasePtr(y), out_desc.get(), - GetBasePtr(out), ToCnnlDataType()); + MLUOpTensorKernel(ctx, CNNL_OP_TENSOR_ADD); } }; @@ -75,22 +46,8 @@ class ElementwiseAddGradMLUKernel : public framework::OpKernel { if (dx->dims() != dout->dims()) { std::vector dst_dims_vec; std::vector reduce_axes; - auto src_dims = dx->dims(); - auto dout_dims = dout->dims(); - - int src_axis = (src_dims.size() < dout_dims.size() ? axis : 0); - for (int ax = 0; ax < dout_dims.size(); ++ax) { - if ((ax < src_axis || ax >= src_axis + src_dims.size()) || - (dout_dims[ax] > 1 && src_dims[ax - src_axis] == 1)) { - reduce_axes.push_back(ax); - } else { - dst_dims_vec.push_back(dout_dims[ax]); - } - } - if (dst_dims_vec.size() == 0) { - // x is scalar - dst_dims_vec.push_back(1); - } + GetReduceAxesAndDstDims(axis, dout->dims(), dx->dims(), &reduce_axes, + &dst_dims_vec); MLUCnnlReduceDesc reduction_desc( reduce_axes, CNNL_REDUCE_ADD, ToCnnlDataType(), @@ -109,22 +66,8 @@ class ElementwiseAddGradMLUKernel : public framework::OpKernel { if (dy->dims() != dout->dims()) { std::vector dst_dims_vec; std::vector reduce_axes; - auto src_dims = dy->dims(); - auto dout_dims = dout->dims(); - - int src_axis = (src_dims.size() < dout_dims.size() ? axis : 0); - for (int ax = 0; ax < dout_dims.size(); ++ax) { - if ((ax < src_axis || ax >= src_axis + src_dims.size()) || - (dout_dims[ax] > 1 && src_dims[ax - src_axis] == 1)) { - reduce_axes.push_back(ax); - } else { - dst_dims_vec.push_back(dout_dims[ax]); - } - } - if (dst_dims_vec.size() == 0) { - // y is scalar - dst_dims_vec.push_back(1); - } + GetReduceAxesAndDstDims(axis, dout->dims(), dy->dims(), &reduce_axes, + &dst_dims_vec); MLUCnnlReduceDesc reduction_desc( reduce_axes, CNNL_REDUCE_ADD, ToCnnlDataType(), diff --git a/paddle/fluid/operators/elementwise/elementwise_mlu.h b/paddle/fluid/operators/elementwise/elementwise_mlu.h new file mode 100644 index 0000000000..156cea81c0 --- /dev/null +++ b/paddle/fluid/operators/elementwise/elementwise_mlu.h @@ -0,0 +1,207 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#ifdef PADDLE_WITH_MLU +#include +#include "paddle/fluid/operators/elementwise/elementwise_op.h" +#include "paddle/fluid/operators/mlu/mlu_baseop.h" + +namespace paddle { +namespace operators { + +inline void GetReduceAxes(const int axis, const framework::DDim& src_ddims, + const framework::DDim& target_ddims, + std::vector* axes) { + int64_t src_dim_size = src_ddims.size(); + int64_t target_dim_size = target_ddims.size(); + for (int64_t i = 0; i < src_dim_size; ++i) { + if (i < axis || i >= target_dim_size + axis) { + axes->push_back(i); + continue; + } + if (src_ddims[i] > target_ddims[i - axis]) { + axes->push_back(i); + } + } +} + +inline void GetReduceAxesAndDstDims(const int axis, + const framework::DDim& src_ddims, + const framework::DDim& target_ddims, + std::vector* reduce_axes, + std::vector* dst_dims_vec) { + int64_t src_dim_size = src_ddims.size(); + int64_t target_dim_size = target_ddims.size(); + + int src_axis = (target_dim_size < src_dim_size ? axis : 0); + for (int ax = 0; ax < src_dim_size; ++ax) { + if ((ax < src_axis || ax >= src_axis + target_dim_size) || + (src_ddims[ax] > 1 && target_ddims[ax - src_axis] == 1)) { + reduce_axes->push_back(ax); + } else { + dst_dims_vec->push_back(src_ddims[ax]); + } + } + if (dst_dims_vec->size() == 0) { + // target_var is scalar + dst_dims_vec->push_back(1); + } +} + +template +void MLUOpTensorKernel(const framework::ExecutionContext& ctx, + const cnnlOpTensorDesc_t op_tensor_op) { + PADDLE_ENFORCE_EQ( + platform::is_mlu_place(ctx.GetPlace()), true, + platform::errors::Unavailable("This kernel only runs on MLU.")); + PADDLE_ENFORCE_EQ((op_tensor_op == CNNL_OP_TENSOR_ADD) || + (op_tensor_op == CNNL_OP_TENSOR_SUB) || + (op_tensor_op == CNNL_OP_TENSOR_MUL), + true, + platform::errors::Unavailable( + "This kernel of MLU only support ADD, SUB, MUL.")); + + auto* x = ctx.Input("X"); + auto* y = ctx.Input("Y"); + auto* out = ctx.Output("Out"); + out->mutable_data(ctx.GetPlace()); + + int axis = ctx.Attr("axis"); + const auto& x_dims = x->dims(); + const auto& y_dims = y->dims(); + axis = + (axis < 0 ? (std::abs(x_dims.size() - y_dims.size()) + axis + 1) : axis); + int max_dim = std::max(x_dims.size(), y_dims.size()); + std::vector x_dims_array(max_dim); + std::vector y_dims_array(max_dim); + std::vector out_dims_array(max_dim); + GetBroadcastDimsArrays(x_dims, y_dims, x_dims_array.data(), + y_dims_array.data(), out_dims_array.data(), max_dim, + axis); + + MLUCnnlTensorDesc x_desc(max_dim, x_dims_array.data(), ToCnnlDataType()); + MLUCnnlTensorDesc y_desc(max_dim, y_dims_array.data(), ToCnnlDataType()); + MLUCnnlTensorDesc out_desc(*out); + MLUCnnlOpTensorDesc op_tensor_desc(op_tensor_op, ToCnnlDataType(), + CNNL_NOT_PROPAGATE_NAN); + + MLUCnnl::OpTensor(ctx, op_tensor_desc.get(), x_desc.get(), GetBasePtr(x), + y_desc.get(), GetBasePtr(y), out_desc.get(), + GetBasePtr(out), ToCnnlDataType()); +} + +// ------------------ BinaryOp ----------------- +enum BINARY_FUNCTOR { + DIV, + DIVNONAN, +}; + +template +void MLUBinary(const framework::ExecutionContext& ctx, + cnnlComputationPreference_t prefer, + const cnnlTensorDescriptor_t x_desc, const void* x, + const cnnlTensorDescriptor_t y_desc, const void* y, + const cnnlTensorDescriptor_t out_desc, void* out); + +template <> +inline void MLUBinary
(const framework::ExecutionContext& ctx, + cnnlComputationPreference_t prefer, + const cnnlTensorDescriptor_t x_desc, const void* x, + const cnnlTensorDescriptor_t y_desc, const void* y, + const cnnlTensorDescriptor_t out_desc, void* out) { + MLUCnnl::Div(ctx, prefer, x_desc, x, y_desc, y, out_desc, out); +} + +template +void MLUBinaryOp(const framework::ExecutionContext& ctx) { + auto* x = ctx.Input("X"); + auto* y = ctx.Input("Y"); + auto* out = ctx.Output("Out"); + out->mutable_data(ctx.GetPlace()); + + int axis = ctx.Attr("axis"); + const auto& x_dims = x->dims(); + const auto& y_dims = y->dims(); + axis = + (axis < 0 ? (std::abs(x_dims.size() - y_dims.size()) + axis + 1) : axis); + int max_dim = std::max(x_dims.size(), y_dims.size()); + std::vector x_dims_array(max_dim); + std::vector y_dims_array(max_dim); + std::vector out_dims_array(max_dim); + GetBroadcastDimsArrays(x_dims, y_dims, x_dims_array.data(), + y_dims_array.data(), out_dims_array.data(), max_dim, + axis); + + MLUCnnlTensorDesc x_desc(max_dim, x_dims_array.data(), ToCnnlDataType()); + MLUCnnlTensorDesc y_desc(max_dim, y_dims_array.data(), ToCnnlDataType()); + MLUCnnlTensorDesc out_desc(*out, CNNL_LAYOUT_ARRAY, ToCnnlDataType()); + + cnnlComputationPreference_t prefer_type = CNNL_COMPUTATION_HIGH_PRECISION; + MLUBinary(ctx, prefer_type, x_desc.get(), GetBasePtr(x), + y_desc.get(), GetBasePtr(y), out_desc.get(), + GetBasePtr(out)); +} + +// ------------------ UnaryOp ----------------- +enum UNARY_FUNCTOR { + NEG, + RECIPROCAL, +}; + +template +void MLUUnary(const framework::ExecutionContext& ctx, + cnnlComputationPreference_t prefer, + const cnnlTensorDescriptor_t input_desc, const void* input, + const cnnlTensorDescriptor_t ouput_desc, void* output); + +template <> +inline void MLUUnary(const framework::ExecutionContext& ctx, + cnnlComputationPreference_t prefer, + const cnnlTensorDescriptor_t input_desc, + const void* input, + const cnnlTensorDescriptor_t output_desc, + void* output) { + MLUCnnl::Neg(ctx, input_desc, input, output_desc, output); +} + +template <> +inline void MLUUnary(const framework::ExecutionContext& ctx, + cnnlComputationPreference_t prefer, + const cnnlTensorDescriptor_t input_desc, + const void* input, + const cnnlTensorDescriptor_t output_desc, + void* output) { + MLUCnnl::Reciprocal(ctx, input_desc, input, output_desc, output); +} + +template +void MLUUnaryOp(const framework::ExecutionContext& ctx) { + auto* x = ctx.Input("X"); + auto* out = ctx.Output("Out"); + + out->mutable_data(ctx.GetPlace()); + + MLUCnnlTensorDesc x_desc(x, CNNL_LAYOUT_ARRAY, ToCnnlDataType()); + MLUCnnlTensorDesc out_desc(*out, CNNL_LAYOUT_ARRAY, ToCnnlDataType()); + + cnnlComputationPreference_t prefer_type = CNNL_COMPUTATION_HIGH_PRECISION; + MLUUnary(ctx, prefer_type, x_desc.get(), GetBasePtr(x), + out_desc.get(), GetBasePtr(out)); +} + +} // namespace operators +} // namespace paddle +#endif diff --git a/paddle/fluid/operators/elementwise/elementwise_mul_op_mlu.cc b/paddle/fluid/operators/elementwise/elementwise_mul_op_mlu.cc index a7505890f4..33603fd73f 100644 --- a/paddle/fluid/operators/elementwise/elementwise_mul_op_mlu.cc +++ b/paddle/fluid/operators/elementwise/elementwise_mul_op_mlu.cc @@ -12,8 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/elementwise/elementwise_mul_op.h" -#include "paddle/fluid/operators/mlu/mlu_baseop.h" +#include "paddle/fluid/operators/elementwise/elementwise_mlu.h" namespace paddle { namespace operators { @@ -21,53 +20,11 @@ namespace operators { using Tensor = framework::Tensor; using MLUDeviceContext = platform::MLUDeviceContext; -static void GetReduceAxes(const int axis, const framework::DDim& src_ddims, - const framework::DDim& target_ddims, - std::vector* axes) { - int64_t src_dim_size = src_ddims.size(); - int64_t target_dim_size = target_ddims.size(); - for (int64_t i = 0; i < src_dim_size; ++i) { - if (i < axis || i >= target_dim_size + axis) { - axes->push_back(i); - continue; - } - if (src_ddims[i] > target_ddims[i - axis]) { - axes->push_back(i); - } - } -} - template class ElementwiseMulMLUKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { - auto* x = ctx.Input("X"); - auto* y = ctx.Input("Y"); - auto* out = ctx.Output("Out"); - out->mutable_data(ctx.GetPlace()); - - int axis = ctx.Attr("axis"); - const auto& x_dims = x->dims(); - const auto& y_dims = y->dims(); - axis = (axis < 0 ? (std::abs(x_dims.size() - y_dims.size()) + axis + 1) - : axis); - int max_dim = std::max(x_dims.size(), y_dims.size()); - std::vector x_dims_array(max_dim); - std::vector y_dims_array(max_dim); - std::vector out_dims_array(max_dim); - GetBroadcastDimsArrays(x_dims, y_dims, x_dims_array.data(), - y_dims_array.data(), out_dims_array.data(), max_dim, - axis); - - MLUCnnlTensorDesc x_desc(max_dim, x_dims_array.data(), ToCnnlDataType()); - MLUCnnlTensorDesc y_desc(max_dim, y_dims_array.data(), ToCnnlDataType()); - MLUCnnlTensorDesc out_desc(*out); - MLUCnnlOpTensorDesc op_tensor_desc(CNNL_OP_TENSOR_MUL, ToCnnlDataType(), - CNNL_NOT_PROPAGATE_NAN); - - MLUCnnl::OpTensor(ctx, op_tensor_desc.get(), x_desc.get(), GetBasePtr(x), - y_desc.get(), GetBasePtr(y), out_desc.get(), - GetBasePtr(out), ToCnnlDataType()); + MLUOpTensorKernel(ctx, CNNL_OP_TENSOR_MUL); } }; diff --git a/paddle/fluid/operators/elementwise/elementwise_sub_op_mlu.cc b/paddle/fluid/operators/elementwise/elementwise_sub_op_mlu.cc new file mode 100644 index 0000000000..7c3d09effa --- /dev/null +++ b/paddle/fluid/operators/elementwise/elementwise_sub_op_mlu.cc @@ -0,0 +1,112 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include + +#include "paddle/fluid/operators/elementwise/elementwise_mlu.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; + +template +class ElementwiseSubMLUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + MLUOpTensorKernel(ctx, CNNL_OP_TENSOR_SUB); + } +}; + +template +class ElementwiseSubGradMLUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto& dev_ctx = + ctx.template device_context(); + auto* x = ctx.Input("X"); + auto* y = ctx.Input("Y"); + auto* dout = ctx.Input(framework::GradVarName("Out")); + auto* dx = ctx.Output(framework::GradVarName("X")); + auto* dy = ctx.Output(framework::GradVarName("Y")); + int axis = ctx.Attr("axis"); + axis = (axis == -1 ? std::abs(x->dims().size() - y->dims().size()) : axis); + + MLUCnnlTensorDesc dout_desc(*dout); + + if (dx) { + dx->mutable_data(ctx.GetPlace()); + if (dx->dims() != dout->dims()) { + std::vector dst_dims_vec; + std::vector reduce_axes; + GetReduceAxesAndDstDims(axis, dout->dims(), dx->dims(), &reduce_axes, + &dst_dims_vec); + + MLUCnnlReduceDesc reduction_desc( + reduce_axes, CNNL_REDUCE_ADD, ToCnnlDataType(), + CNNL_NOT_PROPAGATE_NAN, CNNL_REDUCE_NO_INDICES, CNNL_32BIT_INDICES); + MLUCnnlTensorDesc dx_desc(dst_dims_vec.size(), dst_dims_vec.data(), + ToCnnlDataType()); + MLUCnnl::Reduce(ctx, true /*need_workspace*/, reduction_desc.get(), + nullptr, dout_desc.get(), GetBasePtr(dout), 0, nullptr, + nullptr, dx_desc.get(), GetBasePtr(dx)); + } else { + framework::TensorCopy(*dout, ctx.GetPlace(), dev_ctx, dx); + } + } + if (dy) { + dy->mutable_data(ctx.GetPlace()); + Tensor* tmp_dout = const_cast(dout); + if (dy->dims() != dout->dims()) { + std::vector dst_dims_vec; + std::vector reduce_axes; + GetReduceAxesAndDstDims(axis, dout->dims(), dy->dims(), &reduce_axes, + &dst_dims_vec); + + MLUCnnlReduceDesc reduction_desc( + reduce_axes, CNNL_REDUCE_ADD, ToCnnlDataType(), + CNNL_NOT_PROPAGATE_NAN, CNNL_REDUCE_NO_INDICES, CNNL_32BIT_INDICES); + MLUCnnlTensorDesc dy_desc(dst_dims_vec.size(), dst_dims_vec.data(), + ToCnnlDataType()); + MLUCnnl::Reduce(ctx, true /*need_workspace*/, reduction_desc.get(), + nullptr, dout_desc.get(), GetBasePtr(dout), 0, nullptr, + nullptr, dy_desc.get(), GetBasePtr(dy)); + tmp_dout = dy; + } + + // call neg op, dy = -dout + MLUCnnlTensorDesc tmp_dout_desc(*tmp_dout); + MLUCnnlTensorDesc dy_desc(*dy); + + MLUUnary(ctx, CNNL_COMPUTATION_HIGH_PRECISION, tmp_dout_desc.get(), + GetBasePtr(tmp_dout), dy_desc.get(), GetBasePtr(dy)); + } + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +namespace plat = paddle::platform; + +REGISTER_OP_MLU_KERNEL(elementwise_sub, ops::ElementwiseSubMLUKernel, + ops::ElementwiseSubMLUKernel, + ops::ElementwiseSubMLUKernel); + +REGISTER_OP_MLU_KERNEL(elementwise_sub_grad, + ops::ElementwiseSubGradMLUKernel, + ops::ElementwiseSubGradMLUKernel, + ops::ElementwiseSubGradMLUKernel); diff --git a/paddle/fluid/operators/mlu/mlu_baseop.h b/paddle/fluid/operators/mlu/mlu_baseop.h index 00ad618329..9948c45e24 100644 --- a/paddle/fluid/operators/mlu/mlu_baseop.h +++ b/paddle/fluid/operators/mlu/mlu_baseop.h @@ -45,6 +45,22 @@ enum MLULogicMethod { CNNL_LOGIC_OP_OR = 7, }; +const std::map MLUReduceOpMap = { + {"reduce_all", CNNL_REDUCE_AND}, {"reduce_any", CNNL_REDUCE_OR}, + {"reduce_max", CNNL_REDUCE_MAX}, {"reduce_mean", CNNL_REDUCE_AVG}, + {"reduce_min", CNNL_REDUCE_MIN}, {"reduce_sum", CNNL_REDUCE_ADD}, + {"reduce_prod", CNNL_REDUCE_MUL}, +}; + +inline cnnlReduceOp_t GetMLUCnnlReduceOp(const std::string reduce_name) { + auto iter = MLUReduceOpMap.find(reduce_name); + if (iter != MLUReduceOpMap.end()) { + return iter->second; + } + PADDLE_THROW(platform::errors::InvalidArgument( + "Not support reduce op type of MLU Device: %s", reduce_name)); +} + inline const void* GetBasePtr(const Tensor* t) { return t->data(); } inline void* GetBasePtr(Tensor* t) { return t->data(); } diff --git a/paddle/fluid/operators/reduce_ops/reduce_mean_op_mlu.cc b/paddle/fluid/operators/reduce_ops/reduce_mean_op_mlu.cc index 89e578dbdb..6e5fd59c45 100644 --- a/paddle/fluid/operators/reduce_ops/reduce_mean_op_mlu.cc +++ b/paddle/fluid/operators/reduce_ops/reduce_mean_op_mlu.cc @@ -12,9 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/fluid/operators/reduce_ops/reduce_mean_op.h" -#include "paddle/fluid/operators/mlu/mlu_baseop.h" -#include "paddle/fluid/platform/device/mlu/device_context.h" +#include "paddle/fluid/operators/reduce_ops/reduce_op_mlu.h" namespace paddle { namespace operators { @@ -23,42 +21,7 @@ template class ReduceMeanMLUKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { - auto* input = context.Input("X"); - auto* output = context.Output("Out"); - output->mutable_data(context.GetPlace()); - - bool reduce_all = context.Attr("reduce_all"); - auto dims = context.Attr>("dim"); - auto input_dims = phi::vectorize(input->dims()); - const auto& input_dim_size = input->dims().size(); - std::vector reduce_dims; - if (reduce_all) { - for (size_t i = 0; i < input_dims.size(); i++) { - reduce_dims.push_back(static_cast(i)); - } - } else { - for (size_t i = 0; i < dims.size(); ++i) { - if (dims[i] < 0) { - reduce_dims.push_back(dims[i] + input_dim_size); - } else { - reduce_dims.push_back(dims[i]); - } - } - } - - MLUCnnlTensorDesc input_desc(*input, CNNL_LAYOUT_ARRAY, - ToCnnlDataType(input->dtype())); - MLUCnnlTensorDesc output_desc(*output, CNNL_LAYOUT_ARRAY, - ToCnnlDataType(output->dtype())); - - MLUCnnlReduceDesc reduction_desc( - reduce_dims, CNNL_REDUCE_AVG, ToCnnlDataType(), - CNNL_NOT_PROPAGATE_NAN, CNNL_REDUCE_NO_INDICES, CNNL_32BIT_INDICES); - - MLUCnnl::Reduce(context, true /*need_workspace*/, reduction_desc.get(), - nullptr, input_desc.get(), GetBasePtr(input), - 0 /*indices_size*/, nullptr, nullptr, output_desc.get(), - GetBasePtr(output)); + MLUReduceOp(context, "reduce_mean"); } }; diff --git a/paddle/fluid/operators/reduce_ops/reduce_op_mlu.h b/paddle/fluid/operators/reduce_ops/reduce_op_mlu.h new file mode 100644 index 0000000000..95dda354ca --- /dev/null +++ b/paddle/fluid/operators/reduce_ops/reduce_op_mlu.h @@ -0,0 +1,73 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#ifdef PADDLE_WITH_MLU +#include +#include +#include "paddle/fluid/operators/mlu/mlu_baseop.h" +#include "paddle/fluid/operators/reduce_ops/reduce_op.h" + +namespace paddle { +namespace operators { + +template +void MLUReduceOp(const framework::ExecutionContext& context, + std::string reduce_name) { + PADDLE_ENFORCE_EQ( + platform::is_mlu_place(context.GetPlace()), true, + platform::errors::Unavailable("This kernel only runs on MLU.")); + auto* input = context.Input("X"); + auto* output = context.Output("Out"); + output->mutable_data(context.GetPlace()); + + bool reduce_all = context.Attr("reduce_all"); + auto dims = context.Attr>("dim"); + auto input_dims = phi::vectorize(input->dims()); + const auto& input_dim_size = input->dims().size(); + std::vector reduce_dims; + if (reduce_all) { + for (size_t i = 0; i < input_dims.size(); i++) { + reduce_dims.push_back(static_cast(i)); + } + } else { + for (size_t i = 0; i < dims.size(); ++i) { + if (dims[i] < 0) { + reduce_dims.push_back(dims[i] + input_dim_size); + } else { + reduce_dims.push_back(dims[i]); + } + } + } + + MLUCnnlTensorDesc input_desc(*input, CNNL_LAYOUT_ARRAY, + ToCnnlDataType(input->dtype())); + MLUCnnlTensorDesc output_desc(*output, CNNL_LAYOUT_ARRAY, + ToCnnlDataType(output->dtype())); + + cnnlReduceOp_t reduce_op = GetMLUCnnlReduceOp(reduce_name); + MLUCnnlReduceDesc reduction_desc(reduce_dims, reduce_op, ToCnnlDataType(), + CNNL_NOT_PROPAGATE_NAN, + CNNL_REDUCE_NO_INDICES, CNNL_32BIT_INDICES); + + MLUCnnl::Reduce(context, true /*need_workspace*/, reduction_desc.get(), + nullptr, input_desc.get(), GetBasePtr(input), + 0 /*indices_size*/, nullptr, nullptr, output_desc.get(), + GetBasePtr(output)); +} + +} // namespace operators +} // namespace paddle +#endif diff --git a/paddle/fluid/operators/reduce_ops/reduce_sum_op_mlu.cc b/paddle/fluid/operators/reduce_ops/reduce_sum_op_mlu.cc new file mode 100644 index 0000000000..fab8bb23b1 --- /dev/null +++ b/paddle/fluid/operators/reduce_ops/reduce_sum_op_mlu.cc @@ -0,0 +1,78 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/operators/reduce_ops/reduce_op_mlu.h" + +namespace paddle { +namespace operators { + +template +class ReduceSumMLUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + MLUReduceOp(context, "reduce_sum"); + } +}; + +template +class ReduceSumGradMLUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* in = context.Input("X"); + auto* out_grad = context.Input(framework::GradVarName("Out")); + auto* in_grad = context.Output(framework::GradVarName("X")); + in_grad->mutable_data(context.GetPlace()); + + bool reduce_all = context.Attr("reduce_all"); + auto reduce_dims = context.Attr>("dim"); + auto in_dims = phi::vectorize(in->dims()); + + if (reduce_all) { + reduce_dims.clear(); + for (size_t d = 0; d < in_dims.size(); ++d) { + reduce_dims.push_back(static_cast(d)); + } + } + for (auto& d : reduce_dims) { + if (d < 0) { + d = d + in_dims.size(); + } + } + + Tensor tmp_out(out_grad->dtype()); + auto tmp_output_dims = in_dims; + for (auto d : reduce_dims) { + tmp_output_dims[d] = 1; + } + tmp_out.ShareDataWith(*out_grad); + tmp_out.Resize(phi::make_ddim(tmp_output_dims)); + + MLUCnnlTensorDesc out_desc(tmp_out, CNNL_LAYOUT_ARRAY, ToCnnlDataType()); + MLUCnnlTensorDesc in_grad_desc(*in_grad, CNNL_LAYOUT_ARRAY, + ToCnnlDataType()); + + MLUCnnl::BroadcastTo(context, out_desc.get(), GetBasePtr(&tmp_out), + in_grad_desc.get(), GetBasePtr(in_grad)); + } +}; +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +namespace plat = paddle::platform; + +REGISTER_OP_MLU_KERNEL(reduce_sum, ops::ReduceSumMLUKernel, + ops::ReduceSumMLUKernel); +REGISTER_OP_MLU_KERNEL(reduce_sum_grad, ops::ReduceSumGradMLUKernel, + ops::ReduceSumGradMLUKernel); diff --git a/python/paddle/fluid/tests/unittests/mlu/test_elementwise_sub_op_mlu.py b/python/paddle/fluid/tests/unittests/mlu/test_elementwise_sub_op_mlu.py new file mode 100644 index 0000000000..9ca5359e05 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/mlu/test_elementwise_sub_op_mlu.py @@ -0,0 +1,208 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import numpy as np +import unittest +import sys +sys.path.append("..") +from op_test import OpTest, skip_check_grad_ci +import paddle +import paddle.fluid as fluid + +paddle.enable_static() + +SEED = 2022 + + +class TestElementwiseSubOp(OpTest): + def setUp(self): + self.set_mlu() + self.op_type = "elementwise_sub" + self.init_dtype() + self.init_input_output() + self.init_axis() + + self.inputs = { + 'X': OpTest.np_dtype_to_fluid_dtype(self.x), + 'Y': OpTest.np_dtype_to_fluid_dtype(self.y) + } + self.attrs = {'axis': self.axis} + self.outputs = {'Out': self.out} + + def set_mlu(self): + self.__class__.use_mlu = True + self.place = paddle.device.MLUPlace(0) + + def init_input_output(self): + self.x = np.random.uniform(0.1, 1, [13, 17]).astype(self.dtype) + self.y = np.random.uniform(0.1, 1, [13, 17]).astype(self.dtype) + self.out = np.subtract(self.x, self.y) + + def init_dtype(self): + self.dtype = np.float32 + + def init_axis(self): + self.axis = 0 + + def test_check_output(self): + self.check_output_with_place(self.place) + + def test_check_grad_normal(self): + self.check_grad_with_place(self.place, ['X', 'Y'], 'Out') + + def test_check_grad_ingore_x(self): + self.check_grad_with_place( + self.place, ['Y'], + 'Out', + max_relative_error=0.005, + no_grad_set=set("X")) + + def test_check_grad_ingore_y(self): + self.check_grad_with_place( + self.place, ['X'], + 'Out', + max_relative_error=0.005, + no_grad_set=set('Y')) + + +@skip_check_grad_ci( + reason="[skip shape check] Use y_shape(1) to test broadcast.") +class TestElementwiseSubOp_scalar(TestElementwiseSubOp): + def setUp(self): + self.set_mlu() + self.op_type = "elementwise_sub" + self.inputs = { + 'X': np.random.rand(10, 3, 4).astype(np.float32), + 'Y': np.random.rand(1).astype(np.float32) + } + self.outputs = {'Out': self.inputs['X'] - self.inputs['Y']} + + +class TestElementwiseSubOp_Vector(TestElementwiseSubOp): + def setUp(self): + self.set_mlu() + self.op_type = "elementwise_sub" + self.inputs = { + 'X': np.random.random((100, )).astype("float32"), + 'Y': np.random.random((100, )).astype("float32") + } + self.outputs = {'Out': self.inputs['X'] - self.inputs['Y']} + + +class TestElementwiseSubOp_broadcast_0(TestElementwiseSubOp): + def setUp(self): + self.set_mlu() + self.op_type = "elementwise_sub" + self.inputs = { + 'X': np.random.rand(100, 3, 2).astype(np.float32), + 'Y': np.random.rand(100).astype(np.float32) + } + self.attrs = {'axis': 0} + self.outputs = { + 'Out': self.inputs['X'] - self.inputs['Y'].reshape(100, 1, 1) + } + + +class TestElementwiseSubOp_broadcast_1(TestElementwiseSubOp): + def setUp(self): + self.set_mlu() + self.op_type = "elementwise_sub" + self.inputs = { + 'X': np.random.rand(2, 100, 3).astype(np.float32), + 'Y': np.random.rand(100).astype(np.float32) + } + self.attrs = {'axis': 1} + self.outputs = { + 'Out': self.inputs['X'] - self.inputs['Y'].reshape(1, 100, 1) + } + + +class TestElementwiseSubOp_broadcast_2(TestElementwiseSubOp): + def setUp(self): + self.set_mlu() + self.op_type = "elementwise_sub" + self.inputs = { + 'X': np.random.rand(2, 3, 100).astype(np.float32), + 'Y': np.random.rand(100).astype(np.float32) + } + self.outputs = { + 'Out': self.inputs['X'] - self.inputs['Y'].reshape(1, 1, 100) + } + + +class TestElementwiseSubOp_broadcast_3(TestElementwiseSubOp): + def setUp(self): + self.set_mlu() + self.op_type = "elementwise_sub" + self.inputs = { + 'X': np.random.rand(2, 10, 12, 3).astype(np.float32), + 'Y': np.random.rand(10, 12).astype(np.float32) + } + self.attrs = {'axis': 1} + self.outputs = { + 'Out': self.inputs['X'] - self.inputs['Y'].reshape(1, 10, 12, 1) + } + + +class TestElementwiseSubOp_broadcast_4(TestElementwiseSubOp): + def setUp(self): + self.set_mlu() + self.op_type = "elementwise_sub" + self.inputs = { + 'X': np.random.rand(2, 5, 3, 12).astype(np.float32), + 'Y': np.random.rand(2, 5, 1, 12).astype(np.float32) + } + self.outputs = {'Out': self.inputs['X'] - self.inputs['Y']} + + +class TestElementwiseSubOp_commonuse_1(TestElementwiseSubOp): + def setUp(self): + self.set_mlu() + self.op_type = "elementwise_sub" + self.inputs = { + 'X': np.random.rand(2, 3, 100).astype(np.float32), + 'Y': np.random.rand(1, 1, 100).astype(np.float32) + } + self.outputs = {'Out': self.inputs['X'] - self.inputs['Y']} + + +class TestElementwiseSubOp_commonuse_2(TestElementwiseSubOp): + def setUp(self): + self.set_mlu() + self.op_type = "elementwise_sub" + self.inputs = { + 'X': np.random.rand(10, 3, 1, 4).astype(np.float32), + 'Y': np.random.rand(10, 1, 12, 1).astype(np.float32) + } + self.outputs = {'Out': self.inputs['X'] - self.inputs['Y']} + + +class TestElementwiseSubOp_xsize_lessthan_ysize(TestElementwiseSubOp): + def setUp(self): + self.set_mlu() + self.op_type = "elementwise_sub" + self.inputs = { + 'X': np.random.rand(10, 12).astype(np.float32), + 'Y': np.random.rand(2, 3, 10, 12).astype(np.float32) + } + self.attrs = {'axis': 2} + self.outputs = { + 'Out': self.inputs['X'].reshape(1, 1, 10, 12) - self.inputs['Y'] + } + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/mlu/test_reduce_sum_op_mlu.py b/python/paddle/fluid/tests/unittests/mlu/test_reduce_sum_op_mlu.py new file mode 100644 index 0000000000..d2729d77ab --- /dev/null +++ b/python/paddle/fluid/tests/unittests/mlu/test_reduce_sum_op_mlu.py @@ -0,0 +1,149 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +import sys +sys.path.append("..") +from op_test import OpTest +import paddle + +paddle.enable_static() + + +class TestMLUReduceSumOp(OpTest): + def setUp(self): + self.init_op_type() + self.initTestCase() + self.set_mlu() + self.attrs = { + 'dim': self.axis, + 'keep_dim': self.keep_dim, + 'reduce_all': self.reduce_all + } + self.inputs = {'X': np.random.random(self.shape).astype("float32")} + if self.attrs['reduce_all']: + self.outputs = {'Out': self.inputs['X'].sum()} + else: + self.outputs = { + 'Out': self.inputs['X'].sum(axis=self.axis, + keepdims=self.attrs['keep_dim']) + } + + def set_mlu(self): + self.__class__.use_mlu = True + self.place = paddle.device.MLUPlace(0) + + def test_check_output(self): + self.check_output_with_place(self.place) + + def test_check_grad(self): + self.check_grad_with_place(self.place, ['X'], 'Out') + + def init_op_type(self): + self.op_type = "reduce_sum" + self.use_mkldnn = False + self.keep_dim = False + self.reduce_all = False + + def initTestCase(self): + self.shape = (5, 6, 10) + self.axis = (0, ) + + +class TestSumOp5D(TestMLUReduceSumOp): + def initTestCase(self): + self.shape = (1, 2, 5, 6, 10) + self.axis = (0, ) + + +class TestSumOp6D(TestMLUReduceSumOp): + def initTestCase(self): + self.shape = (1, 1, 2, 5, 6, 10) + self.axis = (0, ) + + +class TestSumOp8D(TestMLUReduceSumOp): + def initTestCase(self): + self.shape = (1, 3, 1, 2, 1, 4, 3, 10) + self.axis = (0, 3) + + +class Test1DReduce(TestMLUReduceSumOp): + def initTestCase(self): + self.shape = 120 + self.axis = (0, ) + + +class Test2DReduce0(TestMLUReduceSumOp): + def initTestCase(self): + self.shape = (20, 10) + self.axis = (0, ) + + +class Test2DReduce1(TestMLUReduceSumOp): + def initTestCase(self): + self.shape = (20, 10) + self.axis = (1, ) + + +class Test3DReduce0(TestMLUReduceSumOp): + def initTestCase(self): + self.shape = (5, 6, 7) + self.axis = (1, ) + + +class Test3DReduce1(TestMLUReduceSumOp): + def initTestCase(self): + self.shape = (5, 6, 7) + self.axis = (2, ) + + +class Test3DReduce2(TestMLUReduceSumOp): + def initTestCase(self): + self.shape = (5, 6, 7) + self.axis = (-2, ) + + +class Test3DReduce3(TestMLUReduceSumOp): + def initTestCase(self): + self.shape = (5, 6, 7) + self.axis = (1, 2) + + +class TestKeepDimReduce(TestMLUReduceSumOp): + def initTestCase(self): + self.shape = (5, 6, 10) + self.axis = (1, ) + self.keep_dim = True + + +class TestKeepDim8DReduce(TestMLUReduceSumOp): + def initTestCase(self): + self.shape = (2, 5, 3, 2, 2, 3, 4, 2) + self.axis = (3, 4, 5) + self.keep_dim = True + + +class TestReduceAll(TestMLUReduceSumOp): + def initTestCase(self): + self.shape = (5, 6, 2, 10) + self.axis = (0, ) + self.reduce_all = True + + +if __name__ == '__main__': + unittest.main() -- GitLab