From 64a7cbd3136330cc9d55de7927e9865d8554a85a Mon Sep 17 00:00:00 2001 From: Zhang Zheng <32410583+ZzSean@users.noreply.github.com> Date: Fri, 18 Mar 2022 17:45:27 +0800 Subject: [PATCH] [Phi]Move hierarchical_sigmoid kernel to phi (#40553) * first commit * fix compile error * support std::vector * fix * fix op support on GPU by chenweihang * pass test * infershape * add set_dtype * fix order * fix * unify the impl of dt and sr * fix --- paddle/fluid/framework/operator.cc | 25 +- paddle/fluid/imperative/prepared_operator.h | 4 + .../operators/hierarchical_sigmoid_op.cc | 55 +---- .../fluid/operators/hierarchical_sigmoid_op.h | 222 ------------------ paddle/phi/infermeta/multiary.cc | 34 +++ paddle/phi/infermeta/multiary.h | 17 ++ paddle/phi/kernels/CMakeLists.txt | 5 +- .../kernels/cpu/hierarchical_sigmoid_grad.h | 110 +++++++++ .../cpu/hierarchical_sigmoid_grad_kernel.cc | 71 ++++++ .../cpu/hierarchical_sigmoid_kernel.cc | 115 +++++++++ .../hierarchical_sigmoid_grad_kernel.h | 42 ++++ .../phi/kernels/hierarchical_sigmoid_kernel.h | 40 ++++ .../hierarchical_sigmoid_grad_kernel.cc | 99 ++++++++ .../hierarchical_sigmoid_grad_kernel.h | 45 ++++ .../ops/compat/hierarchical_sigmoid_sig.cc | 83 +++++++ 15 files changed, 696 insertions(+), 271 deletions(-) delete mode 100644 paddle/fluid/operators/hierarchical_sigmoid_op.h create mode 100644 paddle/phi/kernels/cpu/hierarchical_sigmoid_grad.h create mode 100644 paddle/phi/kernels/cpu/hierarchical_sigmoid_grad_kernel.cc create mode 100644 paddle/phi/kernels/cpu/hierarchical_sigmoid_kernel.cc create mode 100644 paddle/phi/kernels/hierarchical_sigmoid_grad_kernel.h create mode 100644 paddle/phi/kernels/hierarchical_sigmoid_kernel.h create mode 100644 paddle/phi/kernels/selected_rows/hierarchical_sigmoid_grad_kernel.cc create mode 100644 paddle/phi/kernels/selected_rows/hierarchical_sigmoid_grad_kernel.h create mode 100644 paddle/phi/ops/compat/hierarchical_sigmoid_sig.cc diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index ec28c98d598..42fbeb5d29c 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -628,10 +628,12 @@ std::vector ExecutionContext::MultiOutput( bool OpSupportGPU(const std::string& op_type) { // check in new Function kernel first + bool has_phi_kernel = false; auto& kernel_factory = phi::KernelFactory::Instance(); auto kernel_key_map = kernel_factory.SelectKernelMap(phi::TransToPhiKernelName(op_type)); for (auto& kernel : kernel_key_map) { + has_phi_kernel = true; if (platform::is_gpu_place(phi::TransToPhiPlace(kernel.first.backend()))) { return true; } @@ -639,12 +641,19 @@ bool OpSupportGPU(const std::string& op_type) { auto& all_kernels = OperatorWithKernel::AllOpKernels(); auto it = all_kernels.find(op_type); - if (it == all_kernels.end()) { - // All control operator must support GPU - return true; - } - for (auto& kern_pair : it->second) { - if (platform::is_gpu_place(kern_pair.first.place_)) { + if (it != all_kernels.end()) { + for (auto& kern_pair : it->second) { + if (platform::is_gpu_place(kern_pair.first.place_)) { + return true; + } + } + } else { + if (has_phi_kernel) { + // if has phi kernel, but not find phi gpu kernel and fluid gpu kernel, + // this op doesn't support GPU + return false; + } else { + // All control operator must support GPU return true; } } @@ -2347,6 +2356,10 @@ void OperatorWithKernel::BuildPhiKernelContext( const auto& vector_int_attr = BOOST_GET_CONST(std::vector, attr_it->second); 
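Note on the OpSupportGPU hunk above: the function now tracks whether the op has any phi kernel at all, and only assumes GPU support for ops with no registered kernels (control-flow ops). If the op exists solely in the phi registry but has no GPU kernel there and no fluid kernel either, it now correctly reports no GPU support. A minimal sketch of the resulting decision order, with placeholder booleans standing in for the real kernel-registry lookups:

```cpp
// Sketch only: the four flags stand in for the phi/fluid registry lookups.
bool OpSupportGPUSketch(bool phi_has_any_kernel, bool phi_has_gpu_kernel,
                        bool fluid_has_any_kernel, bool fluid_has_gpu_kernel) {
  if (phi_has_gpu_kernel) return true;                    // phi GPU kernel found
  if (fluid_has_any_kernel) return fluid_has_gpu_kernel;  // fall back to fluid kernels
  if (phi_has_any_kernel) return false;                   // phi-only op, no GPU kernel anywhere
  return true;  // no kernels registered at all: control-flow ops must support GPU
}
```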
pt_kernel_context->EmplaceBackAttr(vector_int_attr); + } else if (attr_defs[i].type_index == + std::type_index(typeid(std::vector))) { + pt_kernel_context->EmplaceBackAttr( + BOOST_GET_CONST(std::vector, attr_it->second)); } else { PADDLE_THROW(platform::errors::Unimplemented( "Unsupported cast op attribute `%s` when construct " diff --git a/paddle/fluid/imperative/prepared_operator.h b/paddle/fluid/imperative/prepared_operator.h index f70f44878e3..9daac181d57 100644 --- a/paddle/fluid/imperative/prepared_operator.h +++ b/paddle/fluid/imperative/prepared_operator.h @@ -541,6 +541,10 @@ void BuildDygraphPhiKernelContext( } else if (attr_defs[i].type_index == std::type_index(typeid(std::vector))) { kernel_ctx->EmplaceBackAttr(BOOST_GET_CONST(std::vector, attr)); + } else if (attr_defs[i].type_index == + std::type_index(typeid(std::vector))) { + kernel_ctx->EmplaceBackAttr( + BOOST_GET_CONST(std::vector, attr)); } else { PADDLE_THROW(platform::errors::Unimplemented( "Unsupported cast op attribute `%s` when construct " diff --git a/paddle/fluid/operators/hierarchical_sigmoid_op.cc b/paddle/fluid/operators/hierarchical_sigmoid_op.cc index 9575ab54b32..93f0d3d334f 100644 --- a/paddle/fluid/operators/hierarchical_sigmoid_op.cc +++ b/paddle/fluid/operators/hierarchical_sigmoid_op.cc @@ -12,9 +12,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/hierarchical_sigmoid_op.h" #include #include + +#include "paddle/fluid/framework/infershape_utils.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/phi/infermeta/multiary.h" + namespace paddle { namespace operators { @@ -60,31 +64,6 @@ namespace operators { class HierarchicalSigmoidOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - void InferShape(framework::InferShapeContext* ctx) const override { - OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "hsigmoid"); - OP_INOUT_CHECK(ctx->HasInput("Label"), "Input", "Label", "hsigmoid"); - OP_INOUT_CHECK(ctx->HasInput("W"), "Input", "W", "hsigmoid"); - OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "hsigmoid"); - OP_INOUT_CHECK(ctx->HasOutput("PreOut"), "Output", "PreOut", "hsigmoid"); - - auto with_prefetch = ctx->Attrs().Get("remote_prefetch"); - if (with_prefetch) { - OP_INOUT_CHECK(ctx->HasOutput("W_Out"), "Output", "W_Out", "hsigmoid"); - } - const int64_t input_dims = ctx->GetInputDim("X")[0]; - const int64_t label_dims = ctx->GetInputDim("Label")[0]; - PADDLE_ENFORCE_EQ(input_dims, label_dims, - platform::errors::InvalidArgument( - "The first dimension of " - "input and label is expected to be the same. 
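The attribute hunks above (operator.cc and prepared_operator.h) add one more type_index branch to the kernel-context builders; this is what the commit log's "support std::vector" refers to. The rendered diff strips template arguments, so the element type is not visible here; judging from the neighbouring std::vector<int> branch and this op's int64 vector attribute (height_sections), the new branch presumably reads roughly as follows, with <int64_t> being an inference rather than a quote from the patch:

```cpp
// Assumed reconstruction; the rendered diff drops the template arguments.
else if (attr_defs[i].type_index ==
         std::type_index(typeid(std::vector<int64_t>))) {
  pt_kernel_context->EmplaceBackAttr(
      BOOST_GET_CONST(std::vector<int64_t>, attr_it->second));
}
```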
" - "But received input's first dimension is %d; " - "label's first dimension is %d.", - input_dims, label_dims)); - - std::vector output_shape({input_dims, 1}); - ctx->SetOutputDim("Out", phi::make_ddim(output_shape)); - ctx->ShareLoD("X", /*->*/ "Out"); - } protected: framework::OpKernelType GetExpectedKernelType( @@ -272,22 +251,14 @@ DECLARE_NO_NEED_BUFFER_VARS_INFERER( } // namespace paddle namespace ops = paddle::operators; -REGISTER_OPERATOR( - hierarchical_sigmoid, ops::HierarchicalSigmoidOp, - ops::HierarchicalSigmoidOpMaker, - ops::HierarchicalSigmoidGradMaker, - ops::HierarchicalSigmoidGradMaker); +DECLARE_INFER_SHAPE_FUNCTOR(hierarchical_sigmoid, + HierarchicalSigmoidInferShapeFunctor, + PD_INFER_META(phi::HierarchicalSigmoidInferMeta)); +REGISTER_OPERATOR(hierarchical_sigmoid, ops::HierarchicalSigmoidOp, + ops::HierarchicalSigmoidOpMaker, + ops::HierarchicalSigmoidGradMaker, + ops::HierarchicalSigmoidGradMaker, + HierarchicalSigmoidInferShapeFunctor); REGISTER_OPERATOR(hierarchical_sigmoid_grad, ops::HierarchicalSigmoidGradOp, ops::HierarchicalSigmoidGradOpGradVarTypeInference, ops::HierarchicalSigmoidGradOpNoNeedBufferVarInferer); -REGISTER_OP_CPU_KERNEL( - hierarchical_sigmoid, - ops::HierarchicalSigmoidOpKernel, - ops::HierarchicalSigmoidOpKernel); -REGISTER_OP_CPU_KERNEL( - hierarchical_sigmoid_grad, - ops::HierarchicalSigmoidGradOpKernel, - ops::HierarchicalSigmoidGradOpKernel); diff --git a/paddle/fluid/operators/hierarchical_sigmoid_op.h b/paddle/fluid/operators/hierarchical_sigmoid_op.h deleted file mode 100644 index f11b28cfefb..00000000000 --- a/paddle/fluid/operators/hierarchical_sigmoid_op.h +++ /dev/null @@ -1,222 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
*/ - -#pragma once - -#include -#include -#include -#include -#include -#include - -#include "paddle/fluid/framework/mixed_vector.h" -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/operators/clip_op.h" -#include "paddle/fluid/operators/math/matrix_bit_code.h" -#include "paddle/fluid/platform/transform.h" -#include "paddle/phi/kernels/funcs/math_function.h" - -namespace paddle { -namespace operators { - -template -using EigenMatrix = framework::EigenMatrix; -using platform::Transform; -using framework::LoDTensor; - -static std::vector PathToRows(const LoDTensor& path) { - std::set rows; - const int64_t* paths = path.data(); - for (int64_t i = 0; i < path.numel(); ++i) { - int64_t row = paths[i]; - if (row < 0) { - continue; - } - rows.emplace(row); - } - return std::vector(rows.begin(), rows.end()); -} -template -class HierarchicalSigmoidOpKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - auto& in = GET_DATA_SAFELY(ctx.Input("X"), "Input", "X", - "HierarchicalSigmoid"); - auto& w = GET_DATA_SAFELY(ctx.Input("W"), "Input", "W", - "HierarchicalSigmoid"); - auto* path = ctx.Input("PathTable"); - auto* code = ctx.Input("PathCode"); - auto& label = GET_DATA_SAFELY(ctx.Input("Label"), "Input", - "Label", "HierarchicalSigmoid"); - auto* bias = ctx.Input("Bias"); - auto* out = ctx.Output("Out"); - auto* pre_out = ctx.Output("PreOut"); - size_t num_classes = static_cast(ctx.Attr("num_classes")); - // for remote prefetch - - bool is_custom = false; - if (path) { - is_custom = true; - } - int64_t code_length = - path ? path->dims()[1] : math::FindLastSet(num_classes - 1); - int64_t batch_size = in.dims()[0]; - LoDTensor sum; - auto& dev_ctx = ctx.template device_context(); - auto* pre_out_data = pre_out->mutable_data( - phi::make_ddim({batch_size, code_length}), ctx.GetPlace()); - auto pre_out_mat = EigenMatrix::From(*pre_out); - // Not all class(leaf) nodes' path lengths equal code_length, thus init as - // 0s can avoid out of path's loss. - phi::funcs::SetConstant zero; - zero(dev_ctx, pre_out, static_cast(0.0)); - auto& place = *ctx.template device_context().eigen_device(); - phi::funcs::RowwiseSum row_sum; - - std::unique_ptr> bit_code; - if (!is_custom) { - bit_code.reset(new math::MatrixBitCodeFunctor( - num_classes, label.template data())); - } else { - bit_code.reset(new math::MatrixBitCodeFunctor( - *path, *code, label.template data())); - } - - std::vector sum_dims({batch_size, 1UL}); - sum.mutable_data(phi::make_ddim(sum_dims), ctx.GetPlace()); - auto sum_mat = EigenMatrix::From(sum); - out->mutable_data(ctx.GetPlace()); - auto out_mat = framework::EigenMatrix::From(*out); - if (bias) { - bit_code->Add(*bias, pre_out); - } - bit_code->Mul(pre_out, w, in); - // clip to [-40, 40] - Transform trans; - trans(ctx.template device_context(), pre_out_data, - pre_out_data + pre_out->numel(), pre_out_data, - ClipFunctor(static_cast(-40.0), static_cast(40.0))); - bit_code->Sum(*pre_out, out, static_cast(-1)); - // use softrelu to calculate cross entropy - pre_out_mat.device(place) = (static_cast(1.0) + pre_out_mat.exp()).log(); - row_sum(dev_ctx, *pre_out, &sum); - // TODO(guosheng): Subtract the out of path's loss, since not all - // class(leaf) nodes' path lengths equal code_length. But it won't break the - // gradient check since both have the out of path's loss and will cancel out - // each other. 
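For orientation, here is what the deleted forward kernel above (and its phi replacement later in this patch) computes, written out as formulas. This is a reading of the code, not text taken from the patch: for sample i, let c(i,j) be the j-th node on its code path (from PathTable, or the default tree over num_classes) and d_{ij} the matching bit from PathCode; then

```latex
z_{ij} = \mathrm{clip}\!\left(w_{c(i,j)} \cdot x_i + b_{c(i,j)},\; -40,\; 40\right),
\qquad
\mathrm{Out}_i = \sum_{j} \left[ \log\!\left(1 + e^{z_{ij}}\right) - d_{ij}\, z_{ij} \right].
```

PreOut stores the softrelu values log(1 + e^{z_{ij}}), which is exactly what the backward pass needs.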
- out_mat.device(place) = sum_mat + out_mat; - } -}; - -template -class HierarchicalSigmoidGradOpKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - auto& in = GET_DATA_SAFELY(ctx.Input("X"), "Input", "X", - "HierarchicalSigmoidGrad"); - auto& w = GET_DATA_SAFELY(ctx.Input("W"), "Input", "W", - "HierarchicalSigmoidGrad"); - auto* path = ctx.Input("PathTable"); - auto* code = ctx.Input("PathCode"); - auto* in_grad = ctx.Output(framework::GradVarName("X")); - bool is_sparse = ctx.Attr("is_sparse"); - auto& dev_ctx = ctx.template device_context(); - phi::funcs::SetConstant zero; - auto& label = GET_DATA_SAFELY(ctx.Input("Label"), "Input", - "Label", "HierarchicalSigmoidGrad"); - auto& pre_out = GET_DATA_SAFELY(ctx.Input("PreOut"), "Input", - "PreOut", "HierarchicalSigmoidGrad"); - auto& out_grad = GET_DATA_SAFELY( - ctx.Input(framework::GradVarName("Out")), "Input", - framework::GradVarName("Out"), "HierarchicalSigmoidGrad"); - LoDTensor pre_out_grad; - - pre_out_grad.mutable_data(pre_out.dims(), ctx.GetPlace()); - in_grad->mutable_data(ctx.GetPlace()); - zero(dev_ctx, in_grad, static_cast(0.0)); - - size_t num_classes = static_cast(ctx.Attr("num_classes")); - - bool is_custom = false; - if (path) { - is_custom = true; - } - - std::unique_ptr> bit_code; - if (!is_custom) { - bit_code.reset(new math::MatrixBitCodeFunctor( - num_classes, label.template data())); - } else { - bit_code.reset(new math::MatrixBitCodeFunctor( - *path, *code, label.template data())); - } - - // softrelu derivative - - auto blas = phi::funcs::GetBlas(ctx); - - auto* pre_out_grad_data = pre_out_grad.data(); - auto* pre_out_data = pre_out.template data(); - auto n = pre_out.numel(); - blas.VEXP(n, pre_out_data, pre_out_grad_data); - blas.VINV(n, pre_out_grad_data, pre_out_grad_data); - for (int64_t i = 0; i < n; ++i) { - pre_out_grad_data[i] = 1.0 - pre_out_grad_data[i]; - } - bit_code->Sub(&pre_out_grad); // the gradient of clip(w * x + b) - auto* out_grad_data = out_grad.template data(); - - int64_t dim0 = pre_out_grad.dims()[0]; - int64_t dim1 = pre_out_grad.dims()[1]; - for (int64_t i = 0; i < dim0; ++i) { - T tmp = out_grad_data[i]; - blas.SCAL(dim1, tmp, pre_out_grad_data + i * dim1); - } - // TODO(guosheng): multiply pre_out_grad with subgradient of clipping to - // be consistent with the clipping in forward. 
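The gradient computation above never re-evaluates the logits: PreOut already holds p_{ij} = log(1 + e^{z_{ij}}), so VEXP, VINV and the 1 - x loop recover the softrelu derivative, Sub subtracts the path bit, and SCAL scales each row by the incoming Out@GRAD. As a reading of the code:

```latex
\sigma(z_{ij}) = 1 - e^{-p_{ij}},
\qquad
\frac{\partial\, \mathrm{Out}_i}{\partial z_{ij}} = \sigma(z_{ij}) - d_{ij},
\qquad
\frac{\partial L}{\partial z_{ij}} = \left(\sigma(z_{ij}) - d_{ij}\right)\,\overline{\mathrm{Out}}_i,
```

where \overline{\mathrm{Out}}_i is the upstream gradient. As the pre-existing TODO notes, the forward clip to [-40, 40] has no matching subgradient here.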
- auto* bias_grad = ctx.Output(framework::GradVarName("Bias")); - if (bias_grad) { - bias_grad->mutable_data(ctx.GetPlace()); - zero(dev_ctx, bias_grad, static_cast(0.0)); - bit_code->AddGrad(pre_out_grad, bias_grad); - } - if (!is_sparse) { - auto* w_grad = ctx.Output(framework::GradVarName("W")); - w_grad->mutable_data(ctx.GetPlace()); - zero(dev_ctx, w_grad, static_cast(0.0)); - bit_code->MulGradWeight(pre_out_grad, w_grad, in); - } else { - PADDLE_ENFORCE_NOT_NULL(path, - platform::errors::NotFound( - "Custom tree must be set for sparse mode!")); - framework::Vector real_rows = PathToRows(*path); - auto* w_grad = ctx.Output(framework::GradVarName("W")); - w_grad->set_rows(real_rows); - // Build a map of id -> row_index to speed up finding the index of one id - w_grad->set_height(w.dims()[0]); - auto* w_grad_value = w_grad->mutable_value(); - framework::DDim temp_dim(w.dims()); - temp_dim[0] = real_rows.size(); - w_grad_value->mutable_data(temp_dim, ctx.GetPlace()); - zero(dev_ctx, w_grad_value, static_cast(0.0)); - bit_code->MulGradWeight(pre_out_grad, w_grad, in); - } - bit_code->MulGradError(pre_out_grad, w, in_grad); - } -}; - -} // namespace operators -} // namespace paddle diff --git a/paddle/phi/infermeta/multiary.cc b/paddle/phi/infermeta/multiary.cc index ef75ab573c6..3f77a20af22 100644 --- a/paddle/phi/infermeta/multiary.cc +++ b/paddle/phi/infermeta/multiary.cc @@ -369,6 +369,40 @@ void ConcatInferMeta(const std::vector& x, out->share_lod(*x.at(0)); } +void HierarchicalSigmoidInferMeta(const MetaTensor& x, + const MetaTensor& w, + const MetaTensor& label, + paddle::optional path, + paddle::optional code, + paddle::optional bias, + int num_classes, + bool remote_prefetch, + int trainer_id, + const std::vector& height_sections, + const std::vector& epmap, + const std::vector& table_names, + bool is_sparse, + MetaTensor* out, + MetaTensor* pre_out, + MetaTensor* w_out) { + const int64_t input_dims = x.dims()[0]; + const int64_t label_dims = label.dims()[0]; + PADDLE_ENFORCE_EQ(input_dims, + label_dims, + phi::errors::InvalidArgument( + "The first dimension of " + "input and label is expected to be the same. 
" + "But received input's first dimension is %d; " + "label's first dimension is %d.", + input_dims, + label_dims)); + + std::vector output_shape({input_dims, 1}); + out->set_dims(phi::make_ddim(output_shape)); + out->share_lod(x); + out->set_dtype(x.dtype()); +} + void MultiDotInferMeta(const std::vector& x, MetaTensor* out) { auto inputs_dims = GetMetaTensorsDim(x); diff --git a/paddle/phi/infermeta/multiary.h b/paddle/phi/infermeta/multiary.h index 6de95386dd9..a712ca31de7 100644 --- a/paddle/phi/infermeta/multiary.h +++ b/paddle/phi/infermeta/multiary.h @@ -87,6 +87,23 @@ void ConcatInferMeta(const std::vector& x, MetaTensor* out, MetaConfig config = MetaConfig()); +void HierarchicalSigmoidInferMeta(const MetaTensor& x, + const MetaTensor& w, + const MetaTensor& label, + paddle::optional path, + paddle::optional code, + paddle::optional bias, + int num_classes, + bool remote_prefetch, + int trainer_id, + const std::vector& height_sections, + const std::vector& epmap, + const std::vector& table_names, + bool is_sparse, + MetaTensor* out, + MetaTensor* pre_out, + MetaTensor* w_out); + void MultiDotInferMeta(const std::vector& x, MetaTensor* out); void PsroiPoolInferMeta(const MetaTensor& x, diff --git a/paddle/phi/kernels/CMakeLists.txt b/paddle/phi/kernels/CMakeLists.txt index aa76561c5ce..d140912aa78 100644 --- a/paddle/phi/kernels/CMakeLists.txt +++ b/paddle/phi/kernels/CMakeLists.txt @@ -27,12 +27,15 @@ kernel_library(full_kernel DEPS ${COMMON_KERNEL_DEPS} empty_kernel) # Some kernels depend on some targets that are not commonly used. # These targets are not suitable for common dependencies. # In this case, you need to manually generate them here. -set(MANUAL_BUILD_KERNELS eigh_kernel gumbel_softmax_kernel gumbel_softmax_grad_kernel +set(MANUAL_BUILD_KERNELS eigh_kernel gumbel_softmax_kernel gumbel_softmax_grad_kernel + hierarchical_sigmoid_kernel hierarchical_sigmoid_grad_kernel matrix_power_kernel matrix_power_grad_kernel maxout_kernel maxout_grad_kernel pool_kernel put_along_axis_kernel put_along_axis_grad_kernel segment_pool_kernel segment_pool_grad_kernel softmax_kernel softmax_grad_kernel take_along_axis_kernel take_along_axis_grad_kernel triangular_solve_grad_kernel determinant_grad_kernel reduce_kernel) kernel_library(eigh_kernel DEPS ${COMMON_KERNEL_DEPS} lapack_function) +kernel_library(hierarchical_sigmoid_kernel DEPS ${COMMON_KERNEL_DEPS} matrix_bit_code) +kernel_library(hierarchical_sigmoid_grad_kernel DEPS ${COMMON_KERNEL_DEPS} matrix_bit_code) kernel_library(gumbel_softmax_kernel DEPS ${COMMON_KERNEL_DEPS} softmax) kernel_library(gumbel_softmax_grad_kernel DEPS ${COMMON_KERNEL_DEPS} softmax) kernel_library(reduce_kernel DEPS ${COMMON_KERNEL_DEPS} cast_kernel) diff --git a/paddle/phi/kernels/cpu/hierarchical_sigmoid_grad.h b/paddle/phi/kernels/cpu/hierarchical_sigmoid_grad.h new file mode 100644 index 00000000000..b79aab96c0f --- /dev/null +++ b/paddle/phi/kernels/cpu/hierarchical_sigmoid_grad.h @@ -0,0 +1,110 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/fluid/operators/math/matrix_bit_code.h" +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/selected_rows.h" +#include "paddle/phi/kernels/funcs/eigen/common.h" +#include "paddle/phi/kernels/funcs/eigen/eigen_function.h" +#include "paddle/phi/kernels/funcs/math_function.h" + +namespace phi { + +namespace math = paddle::operators::math; + +template +void HierarchicalSigmoidGradKernelImpl( + const Context& ctx, + const DenseTensor& x, + const DenseTensor& w, + const DenseTensor& label, + const DenseTensor& pre_out, + const DenseTensor& out_grad, + paddle::optional path, + paddle::optional code, + paddle::optional bias, + int num_classes, + bool remote_prefetch, + int trainer_id, + const std::vector& height_sections, + const std::vector& epmap, + const std::vector& table_names, + bool is_sparse, + DenseTensor* x_grad, + DenseTensor* w_grad, + DenseTensor* bias_grad, + SelectedRows* w_grad_sr = nullptr) { + funcs::SetConstant zero; + DenseTensor pre_out_grad; + + pre_out_grad.Resize(pre_out.dims()); + ctx.template Alloc(&pre_out_grad); + ctx.template Alloc(x_grad); + zero(ctx, x_grad, static_cast(0.0)); + + bool is_custom = false; + if (path.get_ptr()) { + is_custom = true; + } + + std::unique_ptr> bit_code; + if (!is_custom) { + bit_code.reset(new math::MatrixBitCodeFunctor( + num_classes, label.template data())); + } else { + bit_code.reset(new math::MatrixBitCodeFunctor( + *(path.get_ptr()), *(code.get_ptr()), label.template data())); + } + + // softrelu derivative + + auto blas = funcs::GetBlas(ctx); + + auto* pre_out_grad_data = pre_out_grad.data(); + auto* pre_out_data = pre_out.template data(); + auto n = pre_out.numel(); + blas.VEXP(n, pre_out_data, pre_out_grad_data); + blas.VINV(n, pre_out_grad_data, pre_out_grad_data); + for (int64_t i = 0; i < n; ++i) { + pre_out_grad_data[i] = 1.0 - pre_out_grad_data[i]; + } + bit_code->Sub(&pre_out_grad); // the gradient of clip(w * x + b) + auto* out_grad_data = out_grad.template data(); + + int64_t dim0 = pre_out_grad.dims()[0]; + int64_t dim1 = pre_out_grad.dims()[1]; + for (int64_t i = 0; i < dim0; ++i) { + T tmp = out_grad_data[i]; + blas.SCAL(dim1, tmp, pre_out_grad_data + i * dim1); + } + // TODO(guosheng): multiply pre_out_grad with subgradient of clipping to + // be consistent with the clipping in forward. + if (bias_grad) { + ctx.template Alloc(bias_grad); + zero(ctx, bias_grad, static_cast(0.0)); + bit_code->AddGrad(pre_out_grad, bias_grad); + } + ctx.template Alloc(w_grad); + zero(ctx, w_grad, static_cast(0.0)); + if (!is_sparse) { + bit_code->MulGradWeight(pre_out_grad, w_grad, x); + } else { + bit_code->MulGradWeight(pre_out_grad, w_grad_sr, x); + } + bit_code->MulGradError(pre_out_grad, w, x_grad); +} + +} // namespace phi diff --git a/paddle/phi/kernels/cpu/hierarchical_sigmoid_grad_kernel.cc b/paddle/phi/kernels/cpu/hierarchical_sigmoid_grad_kernel.cc new file mode 100644 index 00000000000..f64a1a8162a --- /dev/null +++ b/paddle/phi/kernels/cpu/hierarchical_sigmoid_grad_kernel.cc @@ -0,0 +1,71 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
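The "unify the impl of dt and sr" item from the commit log is realized in HierarchicalSigmoidGradKernelImpl above: the impl takes both a DenseTensor* destination and an optional SelectedRows* (defaulting to nullptr), and only the final MulGradWeight call differs between the two. A simplified sketch of that tail dispatch, with generic placeholder types instead of the real Paddle ones:

```cpp
// Placeholder types; the real call sites live in HierarchicalSigmoidGradKernelImpl.
template <typename BitCode, typename Tensor, typename RowsTensor>
void MulGradWeightDispatch(BitCode* bit_code, const Tensor& pre_out_grad,
                           Tensor* w_grad_dense, RowsTensor* w_grad_sr,
                           const Tensor& x, bool is_sparse) {
  if (!is_sparse) {
    bit_code->MulGradWeight(pre_out_grad, w_grad_dense, x);  // dense W@GRAD
  } else {
    bit_code->MulGradWeight(pre_out_grad, w_grad_sr, x);     // SelectedRows W@GRAD
  }
}
```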
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/hierarchical_sigmoid_grad_kernel.h" + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/cpu/hierarchical_sigmoid_grad.h" + +namespace phi { + +template +void HierarchicalSigmoidGradKernel(const Context& ctx, + const DenseTensor& x, + const DenseTensor& w, + const DenseTensor& label, + const DenseTensor& pre_out, + const DenseTensor& out_grad, + paddle::optional path, + paddle::optional code, + paddle::optional bias, + int num_classes, + bool remote_prefetch, + int trainer_id, + const std::vector& height_sections, + const std::vector& epmap, + const std::vector& table_names, + bool is_sparse, + DenseTensor* x_grad, + DenseTensor* w_grad, + DenseTensor* bias_grad) { + HierarchicalSigmoidGradKernelImpl(ctx, + x, + w, + label, + pre_out, + out_grad, + path, + code, + bias, + num_classes, + remote_prefetch, + trainer_id, + height_sections, + epmap, + table_names, + is_sparse, + x_grad, + w_grad, + bias_grad); +} + +} // namespace phi + +PD_REGISTER_KERNEL(hierarchical_sigmoid_grad, + CPU, + ALL_LAYOUT, + phi::HierarchicalSigmoidGradKernel, + float, + double) {} diff --git a/paddle/phi/kernels/cpu/hierarchical_sigmoid_kernel.cc b/paddle/phi/kernels/cpu/hierarchical_sigmoid_kernel.cc new file mode 100644 index 00000000000..096a54f9fb2 --- /dev/null +++ b/paddle/phi/kernels/cpu/hierarchical_sigmoid_kernel.cc @@ -0,0 +1,115 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/phi/kernels/hierarchical_sigmoid_kernel.h" + +#include "paddle/fluid/operators/clip_op.h" +#include "paddle/fluid/operators/math/matrix_bit_code.h" +#include "paddle/fluid/platform/transform.h" +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/eigen/common.h" +#include "paddle/phi/kernels/funcs/eigen/eigen_function.h" +#include "paddle/phi/kernels/funcs/math_function_impl.h" + +namespace phi { + +namespace math = paddle::operators::math; + +template +void HierarchicalSigmoidKernel(const Context& ctx, + const DenseTensor& x, + const DenseTensor& w, + const DenseTensor& label, + paddle::optional path, + paddle::optional code, + paddle::optional bias, + int num_classes, + bool remote_prefetch, + int trainer_id, + const std::vector& height_sections, + const std::vector& epmap, + const std::vector& table_names, + bool is_sparse, + DenseTensor* out, + DenseTensor* pre_out, + DenseTensor* w_out) { + size_t num_classes_st = static_cast(num_classes); + // for remote prefetch + + bool is_custom = false; + if (path.get_ptr()) { + is_custom = true; + } + int64_t code_length = path.get_ptr() ? path.get_ptr()->dims()[1] + : math::FindLastSet(num_classes_st - 1); + int64_t batch_size = x.dims()[0]; + DenseTensor sum; + pre_out->Resize(phi::make_ddim({batch_size, code_length})); + ctx.template Alloc(pre_out); + auto* pre_out_data = pre_out->data(); + auto pre_out_mat = EigenMatrix::From(*pre_out); + // Not all class(leaf) nodes' path lengths equal code_length, thus init as + // 0s can avoid out of path's loss. + funcs::SetConstant zero; + zero(ctx, pre_out, static_cast(0.0)); + auto& place = *ctx.eigen_device(); + funcs::RowwiseSum row_sum; + + std::unique_ptr> bit_code; + if (!is_custom) { + bit_code.reset(new math::MatrixBitCodeFunctor( + num_classes_st, label.template data())); + } else { + bit_code.reset(new math::MatrixBitCodeFunctor( + *(path.get_ptr()), *(code.get_ptr()), label.template data())); + } + + std::vector sum_dims({batch_size, 1UL}); + sum.Resize(phi::make_ddim(sum_dims)); + ctx.template Alloc(&sum); + auto sum_mat = EigenMatrix::From(sum); + ctx.template Alloc(out); + auto out_mat = EigenMatrix::From(*out); + if (bias.get_ptr()) { + bit_code->Add(*(bias.get_ptr()), pre_out); + } + bit_code->Mul(pre_out, w, x); + // clip to [-40, 40] + paddle::platform::Transform trans; + trans(ctx, + pre_out_data, + pre_out_data + pre_out->numel(), + pre_out_data, + paddle::operators::ClipFunctor(static_cast(-40.0), + static_cast(40.0))); + bit_code->Sum(*pre_out, out, static_cast(-1)); + // use softrelu to calculate cross entropy + pre_out_mat.device(place) = (static_cast(1.0) + pre_out_mat.exp()).log(); + row_sum(ctx, *pre_out, &sum); + // TODO(guosheng): Subtract the out of path's loss, since not all + // class(leaf) nodes' path lengths equal code_length. But it won't break the + // gradient check since both have the out of path's loss and will cancel out + // each other. + out_mat.device(place) = sum_mat + out_mat; +} + +} // namespace phi + +PD_REGISTER_KERNEL(hierarchical_sigmoid, + CPU, + ALL_LAYOUT, + phi::HierarchicalSigmoidKernel, + float, + double) {} diff --git a/paddle/phi/kernels/hierarchical_sigmoid_grad_kernel.h b/paddle/phi/kernels/hierarchical_sigmoid_grad_kernel.h new file mode 100644 index 00000000000..f7a327cd3f5 --- /dev/null +++ b/paddle/phi/kernels/hierarchical_sigmoid_grad_kernel.h @@ -0,0 +1,42 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void HierarchicalSigmoidGradKernel(const Context& ctx, + const DenseTensor& x, + const DenseTensor& w, + const DenseTensor& label, + const DenseTensor& pre_out, + const DenseTensor& out_grad, + paddle::optional path, + paddle::optional code, + paddle::optional bias, + int num_classes, + bool remote_prefetch, + int trainer_id, + const std::vector& height_sections, + const std::vector& epmap, + const std::vector& table_names, + bool is_sparse, + DenseTensor* x_grad, + DenseTensor* w_grad, + DenseTensor* bias_grad); + +} // namespace phi diff --git a/paddle/phi/kernels/hierarchical_sigmoid_kernel.h b/paddle/phi/kernels/hierarchical_sigmoid_kernel.h new file mode 100644 index 00000000000..619b022904b --- /dev/null +++ b/paddle/phi/kernels/hierarchical_sigmoid_kernel.h @@ -0,0 +1,40 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void HierarchicalSigmoidKernel(const Context& ctx, + const DenseTensor& x, + const DenseTensor& w, + const DenseTensor& label, + paddle::optional path, + paddle::optional code, + paddle::optional bias, + int num_classes, + bool remote_prefetch, + int trainer_id, + const std::vector& height_sections, + const std::vector& epmap, + const std::vector& table_names, + bool is_sparse, + DenseTensor* out, + DenseTensor* pre_out, + DenseTensor* w_out); + +} // namespace phi diff --git a/paddle/phi/kernels/selected_rows/hierarchical_sigmoid_grad_kernel.cc b/paddle/phi/kernels/selected_rows/hierarchical_sigmoid_grad_kernel.cc new file mode 100644 index 00000000000..80b2a1f6678 --- /dev/null +++ b/paddle/phi/kernels/selected_rows/hierarchical_sigmoid_grad_kernel.cc @@ -0,0 +1,99 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
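A readability note on the new kernel headers: the rendered diff strips template arguments, which is why the signatures show bare paddle::optional and std::vector. The forward declaration presumably reads roughly as below; the optional-reference spelling and the vector element types are inferred from the surrounding code, not copied from the patch:

```cpp
// Assumed reconstruction (template arguments are inferred, not quoted).
template <typename T, typename Context>
void HierarchicalSigmoidKernel(const Context& ctx,
                               const DenseTensor& x,
                               const DenseTensor& w,
                               const DenseTensor& label,
                               paddle::optional<const DenseTensor&> path,
                               paddle::optional<const DenseTensor&> code,
                               paddle::optional<const DenseTensor&> bias,
                               int num_classes,
                               bool remote_prefetch,
                               int trainer_id,
                               const std::vector<int64_t>& height_sections,
                               const std::vector<std::string>& epmap,
                               const std::vector<std::string>& table_names,
                               bool is_sparse,
                               DenseTensor* out,
                               DenseTensor* pre_out,
                               DenseTensor* w_out);
```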
+// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/selected_rows/hierarchical_sigmoid_grad_kernel.h" + +#include "paddle/fluid/framework/mixed_vector.h" +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/cpu/hierarchical_sigmoid_grad.h" + +namespace phi { +namespace sr { + +static std::vector PathToRows(const DenseTensor& path) { + std::set rows; + const int64_t* paths = path.data(); + for (int64_t i = 0; i < path.numel(); ++i) { + int64_t row = paths[i]; + if (row < 0) { + continue; + } + rows.emplace(row); + } + return std::vector(rows.begin(), rows.end()); +} + +template +void HierarchicalSigmoidGradKernel(const Context& ctx, + const DenseTensor& x, + const DenseTensor& w, + const DenseTensor& label, + const DenseTensor& pre_out, + const DenseTensor& out_grad, + paddle::optional path, + paddle::optional code, + paddle::optional bias, + int num_classes, + bool remote_prefetch, + int trainer_id, + const std::vector& height_sections, + const std::vector& epmap, + const std::vector& table_names, + bool is_sparse, + DenseTensor* x_grad, + SelectedRows* w_grad, + DenseTensor* bias_grad) { + PADDLE_ENFORCE_NOT_NULL( + path.get_ptr(), + errors::NotFound("Custom tree must be set for sparse mode!")); + paddle::framework::Vector real_rows = PathToRows(*path); + w_grad->set_rows(real_rows); + // Build a map of id -> row_index to speed up finding the index of one id + w_grad->set_height(w.dims()[0]); + auto* w_grad_value = w_grad->mutable_value(); + phi::DDim temp_dim(w.dims()); + temp_dim[0] = real_rows.size(); + w_grad_value->Resize(temp_dim); + phi::HierarchicalSigmoidGradKernelImpl(ctx, + x, + w, + label, + pre_out, + out_grad, + path, + code, + bias, + num_classes, + remote_prefetch, + trainer_id, + height_sections, + epmap, + table_names, + is_sparse, + x_grad, + w_grad_value, + bias_grad, + w_grad); +} + +} // namespace sr +} // namespace phi + +PD_REGISTER_KERNEL(hierarchical_sigmoid_grad_sr, + CPU, + ALL_LAYOUT, + phi::sr::HierarchicalSigmoidGradKernel, + float, + double) {} diff --git a/paddle/phi/kernels/selected_rows/hierarchical_sigmoid_grad_kernel.h b/paddle/phi/kernels/selected_rows/hierarchical_sigmoid_grad_kernel.h new file mode 100644 index 00000000000..557c8b1bc5e --- /dev/null +++ b/paddle/phi/kernels/selected_rows/hierarchical_sigmoid_grad_kernel.h @@ -0,0 +1,45 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
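The PathToRows helper above determines which rows of W actually receive gradient in the sparse case: it walks the whole PathTable, skips negative entries (padding for shorter paths), and keeps each node id once. A standalone toy program mirroring that logic (independent of Paddle, for illustration only):

```cpp
#include <cstdint>
#include <iostream>
#include <set>
#include <vector>

int main() {
  // Flattened 2x3 path table; -1 marks padding on a shorter path.
  std::vector<int64_t> path = {1, 4, -1, 1, 5, 9};
  std::set<int64_t> rows;  // std::set keeps ids unique and sorted, as in PathToRows
  for (int64_t id : path) {
    if (id >= 0) rows.emplace(id);
  }
  for (int64_t r : rows) std::cout << r << ' ';  // prints: 1 4 5 9
  std::cout << '\n';
  return 0;
}
```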
+ +#pragma once + +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/selected_rows.h" + +namespace phi { +namespace sr { + +template +void HierarchicalSigmoidGradKernel(const Context& ctx, + const DenseTensor& x, + const DenseTensor& w, + const DenseTensor& label, + const DenseTensor& pre_out, + const DenseTensor& out_grad, + paddle::optional path, + paddle::optional code, + paddle::optional bias, + int num_classes, + bool remote_prefetch, + int trainer_id, + const std::vector& height_sections, + const std::vector& epmap, + const std::vector& table_names, + bool is_sparse, + DenseTensor* x_grad, + SelectedRows* w_grad, + DenseTensor* bias_grad); + +} // namespace sr +} // namespace phi diff --git a/paddle/phi/ops/compat/hierarchical_sigmoid_sig.cc b/paddle/phi/ops/compat/hierarchical_sigmoid_sig.cc new file mode 100644 index 00000000000..20183d1a9b0 --- /dev/null +++ b/paddle/phi/ops/compat/hierarchical_sigmoid_sig.cc @@ -0,0 +1,83 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/core/compat/op_utils.h" + +namespace phi { + +KernelSignature HierarchicalSigmoidOpArgumentMapping( + const ArgumentMappingContext& ctx) { + return KernelSignature("hierarchical_sigmoid", + {"X", "W", "Label", "PathTable", "PathCode", "Bias"}, + {"num_classes", + "remote_prefetch", + "trainer_id", + "height_sections", + "epmap", + "table_names", + "is_sparse"}, + {"Out", "PreOut", "W_Out"}); +} + +KernelSignature HierarchicalSigmoidGradOpArgumentMapping( + const ArgumentMappingContext& ctx) { + if (ctx.IsDenseTensorOutput(GradVarName("W"))) { + return KernelSignature( + "hierarchical_sigmoid_grad", + {"X", + "W", + "Label", + "PreOut", + GradVarName("Out"), + "PathTable", + "PathCode", + "Bias"}, + {"num_classes", + "remote_prefetch", + "trainer_id", + "height_sections", + "epmap", + "table_names", + "is_sparse"}, + {GradVarName("X"), GradVarName("W"), GradVarName("Bias")}); + } else if (ctx.IsSelectedRowsOutput(GradVarName("W"))) { + return KernelSignature( + "hierarchical_sigmoid_grad_sr", + {"X", + "W", + "Label", + "PreOut", + GradVarName("Out"), + "PathTable", + "PathCode", + "Bias"}, + {"num_classes", + "remote_prefetch", + "trainer_id", + "height_sections", + "epmap", + "table_names", + "is_sparse"}, + {GradVarName("X"), GradVarName("W"), GradVarName("Bias")}); + } else { + return KernelSignature("unregistered", {}, {}, {}); + } +} + +} // namespace phi + +PD_REGISTER_ARG_MAPPING_FN(hierarchical_sigmoid, + phi::HierarchicalSigmoidOpArgumentMapping); +PD_REGISTER_ARG_MAPPING_FN(hierarchical_sigmoid_grad, + phi::HierarchicalSigmoidGradOpArgumentMapping); -- GitLab
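The compat mapping above is what routes the legacy hierarchical_sigmoid_grad op to the right phi kernel at runtime: a DenseTensor W@GRAD selects hierarchical_sigmoid_grad, a SelectedRows W@GRAD (the is_sparse path) selects hierarchical_sigmoid_grad_sr, and anything else falls back to the "unregistered" signature. The same dense/selected-rows dispatch pattern, reduced to a hypothetical op with fewer arguments:

```cpp
// Hypothetical "my_op" with a dense/SelectedRows W@GRAD split (illustration only).
KernelSignature MyOpGradArgumentMapping(const ArgumentMappingContext& ctx) {
  if (ctx.IsDenseTensorOutput(GradVarName("W"))) {
    return KernelSignature("my_op_grad", {"X", "W", GradVarName("Out")},
                           {"is_sparse"}, {GradVarName("X"), GradVarName("W")});
  } else if (ctx.IsSelectedRowsOutput(GradVarName("W"))) {
    return KernelSignature("my_op_grad_sr", {"X", "W", GradVarName("Out")},
                           {"is_sparse"}, {GradVarName("X"), GradVarName("W")});
  }
  return KernelSignature("unregistered", {}, {}, {});
}
```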