// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/mlu/mlu_baseop.h"

namespace paddle {
namespace operators {

template <typename T>
class ArgMaxMLUKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* x = ctx.Input<framework::Tensor>("X");
    auto* out = ctx.Output<framework::Tensor>("Out");
    auto axis = static_cast<int>(ctx.Attr<int64_t>("axis"));
    auto dtype = ctx.Attr<int>("dtype");
    const bool& flatten = ctx.Attr<bool>("flatten");

    if (x->numel() == 0) return;
    // dtype must be INT32 (proto enum value 2) or INT64 (proto enum value 3).
    PADDLE_ENFORCE_EQ(
        (dtype == 2 || dtype == 3),
        true,
        platform::errors::InvalidArgument(
            "The attribute of dtype in argmax op must be [%s] or [%s], "
            "but received [%s]",
            paddle::framework::DataTypeToString(
                framework::proto::VarType::INT64),
            paddle::framework::DataTypeToString(
                framework::proto::VarType::INT32),
            paddle::framework::DataTypeToString(
                static_cast<framework::proto::VarType::Type>(dtype))));

    // Normalize a negative axis into the [0, rank) range.
    if (axis < 0) {
      framework::DDim x_dims = x->dims();
      axis += x_dims.size();
    }

    framework::Tensor flatten_x(x->type());
    flatten_x.ShareDataWith(*x);
    if (flatten) {
      flatten_x.Resize(phi::make_ddim({x->numel()}));
      // When flatten is true, reduce along axis 0 of the flattened tensor.
      axis = 0;
    }
    std::vector<int> reduce_dims;
    reduce_dims.push_back(axis);

    // Total number of output elements, used to size the int32 indices buffer.
    auto out_dims = out->dims();
    int out_count = out_dims[0];
    for (int i = 1; i < out_dims.size(); i++) {
      out_count = out_count * out_dims[i];
    }
    size_t indices_size_inbytes = out_count * sizeof(int32_t);
    auto& dev_ctx = ctx.template device_context<MLUDeviceContext>();
    framework::Tensor value_out =
        ctx.AllocateTmpTensor<T, MLUDeviceContext>(out->dims(), dev_ctx);
    MLUCnnlTensorDesc value_out_desc(value_out);
    MLUCnnlTensorDesc input_desc(
        flatten_x, CNNL_LAYOUT_ARRAY, ToCnnlDataType(flatten_x.dtype()));
    MLUCnnlReduceDesc reduction_desc(reduce_dims,
                                     CNNL_REDUCE_MAX,
                                     ToCnnlDataType<T>(),
                                     CNNL_NOT_PROPAGATE_NAN,
                                     CNNL_REDUCE_ONLY_INDICES,
                                     CNNL_32BIT_INDICES);

    if (dtype == 2) {
      // INT32 indices: CNNL writes the int32 indices directly into out.
      out->template mutable_data<int32_t>(ctx.GetPlace());
      MLUCnnl::Reduce(ctx,
                      true /*need_workspace*/,
                      reduction_desc.get(),
                      nullptr,
                      input_desc.get(),
                      GetBasePtr(&flatten_x),
                      indices_size_inbytes /*indices_size*/,
                      GetBasePtr(out),
                      nullptr,
                      value_out_desc.get(),
                      GetBasePtr(&value_out));
    } else {
      // INT64 indices: reduce into an int32 temporary, then cast to int64.
      out->template mutable_data<int64_t>(ctx.GetPlace());
      framework::Tensor out_int32 =
          ctx.AllocateTmpTensor<int32_t, MLUDeviceContext>(out->dims(),
                                                           dev_ctx);
      MLUCnnl::Reduce(ctx,
                      true /*need_workspace*/,
                      reduction_desc.get(),
                      nullptr,
                      input_desc.get(),
                      GetBasePtr(&flatten_x),
                      indices_size_inbytes /*indices_size*/,
                      GetBasePtr(&out_int32),
                      nullptr,
                      value_out_desc.get(),
                      GetBasePtr(&value_out));

      // Cast the indices from int32 to int64.
      MLUCnnlTensorDesc out_int32_desc(out_int32);
      MLUCnnlTensorDesc cast_output_desc(*out);
      cnnlCastDataType_t cast_type = GetCastDataType(VT::INT32, VT::INT64);
      MLUCnnl::Cast(ctx,
                    cast_type,
                    out_int32_desc.get(),
                    GetBasePtr(&out_int32),
                    cast_output_desc.get(),
                    GetBasePtr(out));
    }
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
namespace plat = paddle::platform;

REGISTER_OP_MLU_KERNEL(arg_max,
                       ops::ArgMaxMLUKernel<int>,
                       ops::ArgMaxMLUKernel<float>,
                       ops::ArgMaxMLUKernel<plat::float16>);