Unverified commit 64a7cbd3 authored by Zhang Zheng, committed by GitHub

[Phi]Move hierarchical_sigmoid kernel to phi (#40553)

* first commit

* fix compile error

* support std::vector<std::string>

* fix

* fix op support on GPU by chenweihang

* pass test

* infershape

* add set_dtype

* fix order

* fix

* unify the impl of dt and sr

* fix
Parent 161d27dc
@@ -628,10 +628,12 @@ std::vector<Tensor*> ExecutionContext::MultiOutput<Tensor>(
bool OpSupportGPU(const std::string& op_type) {
// check in new Function kernel first
bool has_phi_kernel = false;
auto& kernel_factory = phi::KernelFactory::Instance();
auto kernel_key_map =
kernel_factory.SelectKernelMap(phi::TransToPhiKernelName(op_type));
for (auto& kernel : kernel_key_map) {
has_phi_kernel = true;
if (platform::is_gpu_place(phi::TransToPhiPlace(kernel.first.backend()))) {
return true;
}
@@ -639,12 +641,19 @@ bool OpSupportGPU(const std::string& op_type) {
auto& all_kernels = OperatorWithKernel::AllOpKernels();
auto it = all_kernels.find(op_type);
if (it == all_kernels.end()) {
// All control operator must support GPU
return true;
}
for (auto& kern_pair : it->second) {
if (platform::is_gpu_place(kern_pair.first.place_)) {
if (it != all_kernels.end()) {
for (auto& kern_pair : it->second) {
if (platform::is_gpu_place(kern_pair.first.place_)) {
return true;
}
}
} else {
if (has_phi_kernel) {
// if has phi kernel, but not find phi gpu kernel and fluid gpu kernel,
// this op doesn't support GPU
return false;
} else {
// All control operator must support GPU
return true;
}
}
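
The net effect of the OpSupportGPU change above: an operator registered only in phi now reports GPU support from its phi kernels instead of falling through to the old unconditional `return true`. Below is a standalone sketch of the resulting decision order, with plain containers standing in for KernelFactory and AllOpKernels() (illustrative assumption, not code from this patch):

#include <map>
#include <set>
#include <string>

// Sketch of OpSupportGPU after this hunk. phi_ops / phi_gpu_ops / fluid_gpu
// are stand-ins for the phi KernelFactory and the fluid kernel registry.
bool OpSupportGPUSketch(const std::string& op_type,
                        const std::set<std::string>& phi_ops,      // ops with any phi kernel
                        const std::set<std::string>& phi_gpu_ops,  // ops with a phi GPU kernel
                        const std::map<std::string, bool>& fluid_gpu) {
  if (phi_gpu_ops.count(op_type)) return true;   // a phi GPU kernel exists
  auto it = fluid_gpu.find(op_type);
  if (it != fluid_gpu.end()) return it->second;  // fall back to the fluid registry
  // No fluid kernels at all: a phi-only op without a GPU kernel does not
  // support GPU; an op with no kernels anywhere (a control op) is assumed to.
  return phi_ops.count(op_type) == 0;
}
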
@@ -2347,6 +2356,10 @@ void OperatorWithKernel::BuildPhiKernelContext(
const auto& vector_int_attr =
BOOST_GET_CONST(std::vector<int>, attr_it->second);
pt_kernel_context->EmplaceBackAttr(vector_int_attr);
} else if (attr_defs[i].type_index ==
std::type_index(typeid(std::vector<std::string>))) {
pt_kernel_context->EmplaceBackAttr(
BOOST_GET_CONST(std::vector<std::string>, attr_it->second));
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported cast op attribute `%s` when construct "
......
@@ -541,6 +541,10 @@ void BuildDygraphPhiKernelContext(
} else if (attr_defs[i].type_index ==
std::type_index(typeid(std::vector<int>))) {
kernel_ctx->EmplaceBackAttr(BOOST_GET_CONST(std::vector<int>, attr));
} else if (attr_defs[i].type_index ==
std::type_index(typeid(std::vector<std::string>))) {
kernel_ctx->EmplaceBackAttr(
BOOST_GET_CONST(std::vector<std::string>, attr));
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported cast op attribute `%s` when construct "
......
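
Both kernel-context builders (static graph above, dygraph here) gain the same branch because hierarchical_sigmoid carries attributes of type std::vector<std::string> — epmap and table_names in the kernel signatures further down. A reduced sketch of the type_index dispatch those builders use, with std::variant standing in for the framework's attribute variant (an assumption made only for illustration):

#include <cstddef>
#include <string>
#include <typeindex>
#include <variant>
#include <vector>

// Reduced stand-in for the framework Attribute variant and the dispatch in
// BuildPhiKernelContext: the branch is chosen by the attr type the kernel declares.
using Attr = std::variant<std::vector<int>, std::vector<std::string>>;

std::size_t EmplaceAttrSketch(const std::type_index& expected, const Attr& attr) {
  if (expected == std::type_index(typeid(std::vector<int>))) {
    return std::get<std::vector<int>>(attr).size();          // existing branch
  }
  if (expected == std::type_index(typeid(std::vector<std::string>))) {
    return std::get<std::vector<std::string>>(attr).size();  // branch added here
  }
  return 0;  // the real code throws Unimplemented instead
}
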
@@ -12,9 +12,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/hierarchical_sigmoid_op.h"
#include <string>
#include <vector>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/infermeta/multiary.h"
namespace paddle {
namespace operators {
@@ -60,31 +64,6 @@ namespace operators {
class HierarchicalSigmoidOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "hsigmoid");
OP_INOUT_CHECK(ctx->HasInput("Label"), "Input", "Label", "hsigmoid");
OP_INOUT_CHECK(ctx->HasInput("W"), "Input", "W", "hsigmoid");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "hsigmoid");
OP_INOUT_CHECK(ctx->HasOutput("PreOut"), "Output", "PreOut", "hsigmoid");
auto with_prefetch = ctx->Attrs().Get<bool>("remote_prefetch");
if (with_prefetch) {
OP_INOUT_CHECK(ctx->HasOutput("W_Out"), "Output", "W_Out", "hsigmoid");
}
const int64_t input_dims = ctx->GetInputDim("X")[0];
const int64_t label_dims = ctx->GetInputDim("Label")[0];
PADDLE_ENFORCE_EQ(input_dims, label_dims,
platform::errors::InvalidArgument(
"The first dimension of "
"input and label is expected to be the same. "
"But received input's first dimension is %d; "
"label's first dimension is %d.",
input_dims, label_dims));
std::vector<int64_t> output_shape({input_dims, 1});
ctx->SetOutputDim("Out", phi::make_ddim(output_shape));
ctx->ShareLoD("X", /*->*/ "Out");
}
protected:
framework::OpKernelType GetExpectedKernelType(
@@ -272,22 +251,14 @@ DECLARE_NO_NEED_BUFFER_VARS_INFERER(
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(
hierarchical_sigmoid, ops::HierarchicalSigmoidOp,
ops::HierarchicalSigmoidOpMaker<int>,
ops::HierarchicalSigmoidGradMaker<paddle::framework::OpDesc>,
ops::HierarchicalSigmoidGradMaker<paddle::imperative::OpBase>);
DECLARE_INFER_SHAPE_FUNCTOR(hierarchical_sigmoid,
HierarchicalSigmoidInferShapeFunctor,
PD_INFER_META(phi::HierarchicalSigmoidInferMeta));
REGISTER_OPERATOR(hierarchical_sigmoid, ops::HierarchicalSigmoidOp,
ops::HierarchicalSigmoidOpMaker<int>,
ops::HierarchicalSigmoidGradMaker<paddle::framework::OpDesc>,
ops::HierarchicalSigmoidGradMaker<paddle::imperative::OpBase>,
HierarchicalSigmoidInferShapeFunctor);
REGISTER_OPERATOR(hierarchical_sigmoid_grad, ops::HierarchicalSigmoidGradOp,
ops::HierarchicalSigmoidGradOpGradVarTypeInference,
ops::HierarchicalSigmoidGradOpNoNeedBufferVarInferer);
REGISTER_OP_CPU_KERNEL(
hierarchical_sigmoid,
ops::HierarchicalSigmoidOpKernel<paddle::platform::CPUDeviceContext, float>,
ops::HierarchicalSigmoidOpKernel<paddle::platform::CPUDeviceContext,
double>);
REGISTER_OP_CPU_KERNEL(
hierarchical_sigmoid_grad,
ops::HierarchicalSigmoidGradOpKernel<paddle::platform::CPUDeviceContext,
float>,
ops::HierarchicalSigmoidGradOpKernel<paddle::platform::CPUDeviceContext,
double>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <iostream>
#include <iterator>
#include <memory>
#include <set>
#include <string>
#include <vector>
#include "paddle/fluid/framework/mixed_vector.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/clip_op.h"
#include "paddle/fluid/operators/math/matrix_bit_code.h"
#include "paddle/fluid/platform/transform.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace paddle {
namespace operators {
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
using platform::Transform;
using framework::LoDTensor;
static std::vector<int64_t> PathToRows(const LoDTensor& path) {
std::set<int64_t> rows;
const int64_t* paths = path.data<int64_t>();
for (int64_t i = 0; i < path.numel(); ++i) {
int64_t row = paths[i];
if (row < 0) {
continue;
}
rows.emplace(row);
}
return std::vector<int64_t>(rows.begin(), rows.end());
}
template <typename DeviceContext, typename T>
class HierarchicalSigmoidOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto& in = GET_DATA_SAFELY(ctx.Input<LoDTensor>("X"), "Input", "X",
"HierarchicalSigmoid");
auto& w = GET_DATA_SAFELY(ctx.Input<LoDTensor>("W"), "Input", "W",
"HierarchicalSigmoid");
auto* path = ctx.Input<LoDTensor>("PathTable");
auto* code = ctx.Input<LoDTensor>("PathCode");
auto& label = GET_DATA_SAFELY(ctx.Input<LoDTensor>("Label"), "Input",
"Label", "HierarchicalSigmoid");
auto* bias = ctx.Input<LoDTensor>("Bias");
auto* out = ctx.Output<LoDTensor>("Out");
auto* pre_out = ctx.Output<LoDTensor>("PreOut");
size_t num_classes = static_cast<size_t>(ctx.Attr<int>("num_classes"));
// for remote prefetch
bool is_custom = false;
if (path) {
is_custom = true;
}
int64_t code_length =
path ? path->dims()[1] : math::FindLastSet(num_classes - 1);
int64_t batch_size = in.dims()[0];
LoDTensor sum;
auto& dev_ctx = ctx.template device_context<DeviceContext>();
auto* pre_out_data = pre_out->mutable_data<T>(
phi::make_ddim({batch_size, code_length}), ctx.GetPlace());
auto pre_out_mat = EigenMatrix<T>::From(*pre_out);
// Not all class(leaf) nodes' path lengths equal code_length, thus init as
// 0s can avoid out of path's loss.
phi::funcs::SetConstant<DeviceContext, T> zero;
zero(dev_ctx, pre_out, static_cast<T>(0.0));
auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
phi::funcs::RowwiseSum<DeviceContext, T> row_sum;
std::unique_ptr<math::MatrixBitCodeFunctor<T>> bit_code;
if (!is_custom) {
bit_code.reset(new math::MatrixBitCodeFunctor<T>(
num_classes, label.template data<int64_t>()));
} else {
bit_code.reset(new math::MatrixBitCodeFunctor<T>(
*path, *code, label.template data<int64_t>()));
}
std::vector<int64_t> sum_dims({batch_size, 1UL});
sum.mutable_data<T>(phi::make_ddim(sum_dims), ctx.GetPlace());
auto sum_mat = EigenMatrix<T>::From(sum);
out->mutable_data<T>(ctx.GetPlace());
auto out_mat = framework::EigenMatrix<T>::From(*out);
if (bias) {
bit_code->Add(*bias, pre_out);
}
bit_code->Mul(pre_out, w, in);
// clip to [-40, 40]
Transform<DeviceContext> trans;
trans(ctx.template device_context<DeviceContext>(), pre_out_data,
pre_out_data + pre_out->numel(), pre_out_data,
ClipFunctor<T>(static_cast<T>(-40.0), static_cast<T>(40.0)));
bit_code->Sum(*pre_out, out, static_cast<T>(-1));
// use softrelu to calculate cross entropy
pre_out_mat.device(place) = (static_cast<T>(1.0) + pre_out_mat.exp()).log();
row_sum(dev_ctx, *pre_out, &sum);
// TODO(guosheng): Subtract the out of path's loss, since not all
// class(leaf) nodes' path lengths equal code_length. But it won't break the
// gradient check since both have the out of path's loss and will cancel out
// each other.
out_mat.device(place) = sum_mat + out_mat;
}
};
template <typename DeviceContext, typename T>
class HierarchicalSigmoidGradOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto& in = GET_DATA_SAFELY(ctx.Input<LoDTensor>("X"), "Input", "X",
"HierarchicalSigmoidGrad");
auto& w = GET_DATA_SAFELY(ctx.Input<LoDTensor>("W"), "Input", "W",
"HierarchicalSigmoidGrad");
auto* path = ctx.Input<LoDTensor>("PathTable");
auto* code = ctx.Input<LoDTensor>("PathCode");
auto* in_grad = ctx.Output<LoDTensor>(framework::GradVarName("X"));
bool is_sparse = ctx.Attr<bool>("is_sparse");
auto& dev_ctx = ctx.template device_context<DeviceContext>();
phi::funcs::SetConstant<DeviceContext, T> zero;
auto& label = GET_DATA_SAFELY(ctx.Input<LoDTensor>("Label"), "Input",
"Label", "HierarchicalSigmoidGrad");
auto& pre_out = GET_DATA_SAFELY(ctx.Input<LoDTensor>("PreOut"), "Input",
"PreOut", "HierarchicalSigmoidGrad");
auto& out_grad = GET_DATA_SAFELY(
ctx.Input<LoDTensor>(framework::GradVarName("Out")), "Input",
framework::GradVarName("Out"), "HierarchicalSigmoidGrad");
LoDTensor pre_out_grad;
pre_out_grad.mutable_data<T>(pre_out.dims(), ctx.GetPlace());
in_grad->mutable_data<T>(ctx.GetPlace());
zero(dev_ctx, in_grad, static_cast<T>(0.0));
size_t num_classes = static_cast<size_t>(ctx.Attr<int>("num_classes"));
bool is_custom = false;
if (path) {
is_custom = true;
}
std::unique_ptr<math::MatrixBitCodeFunctor<T>> bit_code;
if (!is_custom) {
bit_code.reset(new math::MatrixBitCodeFunctor<T>(
num_classes, label.template data<int64_t>()));
} else {
bit_code.reset(new math::MatrixBitCodeFunctor<T>(
*path, *code, label.template data<int64_t>()));
}
// softrelu derivative
auto blas = phi::funcs::GetBlas<DeviceContext, T>(ctx);
auto* pre_out_grad_data = pre_out_grad.data<T>();
auto* pre_out_data = pre_out.template data<T>();
auto n = pre_out.numel();
blas.VEXP(n, pre_out_data, pre_out_grad_data);
blas.VINV(n, pre_out_grad_data, pre_out_grad_data);
for (int64_t i = 0; i < n; ++i) {
pre_out_grad_data[i] = 1.0 - pre_out_grad_data[i];
}
bit_code->Sub(&pre_out_grad); // the gradient of clip(w * x + b)
auto* out_grad_data = out_grad.template data<T>();
int64_t dim0 = pre_out_grad.dims()[0];
int64_t dim1 = pre_out_grad.dims()[1];
for (int64_t i = 0; i < dim0; ++i) {
T tmp = out_grad_data[i];
blas.SCAL(dim1, tmp, pre_out_grad_data + i * dim1);
}
// TODO(guosheng): multiply pre_out_grad with subgradient of clipping to
// be consistent with the clipping in forward.
auto* bias_grad = ctx.Output<LoDTensor>(framework::GradVarName("Bias"));
if (bias_grad) {
bias_grad->mutable_data<T>(ctx.GetPlace());
zero(dev_ctx, bias_grad, static_cast<T>(0.0));
bit_code->AddGrad(pre_out_grad, bias_grad);
}
if (!is_sparse) {
auto* w_grad = ctx.Output<LoDTensor>(framework::GradVarName("W"));
w_grad->mutable_data<T>(ctx.GetPlace());
zero(dev_ctx, w_grad, static_cast<T>(0.0));
bit_code->MulGradWeight(pre_out_grad, w_grad, in);
} else {
PADDLE_ENFORCE_NOT_NULL(path,
platform::errors::NotFound(
"Custom tree must be set for sparse mode!"));
framework::Vector<int64_t> real_rows = PathToRows(*path);
auto* w_grad = ctx.Output<phi::SelectedRows>(framework::GradVarName("W"));
w_grad->set_rows(real_rows);
// Build a map of id -> row_index to speed up finding the index of one id
w_grad->set_height(w.dims()[0]);
auto* w_grad_value = w_grad->mutable_value();
framework::DDim temp_dim(w.dims());
temp_dim[0] = real_rows.size();
w_grad_value->mutable_data<T>(temp_dim, ctx.GetPlace());
zero(dev_ctx, w_grad_value, static_cast<T>(0.0));
bit_code->MulGradWeight(pre_out_grad, w_grad, in);
}
bit_code->MulGradError(pre_out_grad, w, in_grad);
}
};
} // namespace operators
} // namespace paddle
@@ -369,6 +369,40 @@ void ConcatInferMeta(const std::vector<MetaTensor*>& x,
out->share_lod(*x.at(0));
}
void HierarchicalSigmoidInferMeta(const MetaTensor& x,
const MetaTensor& w,
const MetaTensor& label,
paddle::optional<const MetaTensor&> path,
paddle::optional<const MetaTensor&> code,
paddle::optional<const MetaTensor&> bias,
int num_classes,
bool remote_prefetch,
int trainer_id,
const std::vector<int64_t>& height_sections,
const std::vector<std::string>& epmap,
const std::vector<std::string>& table_names,
bool is_sparse,
MetaTensor* out,
MetaTensor* pre_out,
MetaTensor* w_out) {
const int64_t input_dims = x.dims()[0];
const int64_t label_dims = label.dims()[0];
PADDLE_ENFORCE_EQ(input_dims,
label_dims,
phi::errors::InvalidArgument(
"The first dimension of "
"input and label is expected to be the same. "
"But received input's first dimension is %d; "
"label's first dimension is %d.",
input_dims,
label_dims));
std::vector<int64_t> output_shape({input_dims, 1});
out->set_dims(phi::make_ddim(output_shape));
out->share_lod(x);
out->set_dtype(x.dtype());
}
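
Compared with the removed fluid InferShape, the meta function additionally calls out->set_dtype(x.dtype()) (the "add set_dtype" item in the commit message); the shape contract for Out itself is unchanged. Restated standalone (a sketch, not framework code):

#include <cstdint>
#include <vector>

// Out keeps x's batch dimension and gains a single trailing dimension,
// e.g. x dims {64, 128} with label dims {64, 1}  ->  Out dims {64, 1}.
std::vector<int64_t> HSigmoidOutDims(const std::vector<int64_t>& x_dims) {
  return {x_dims[0], 1};
}
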
void MultiDotInferMeta(const std::vector<MetaTensor*>& x, MetaTensor* out) {
auto inputs_dims = GetMetaTensorsDim(x);
......
@@ -87,6 +87,23 @@ void ConcatInferMeta(const std::vector<MetaTensor*>& x,
MetaTensor* out,
MetaConfig config = MetaConfig());
void HierarchicalSigmoidInferMeta(const MetaTensor& x,
const MetaTensor& w,
const MetaTensor& label,
paddle::optional<const MetaTensor&> path,
paddle::optional<const MetaTensor&> code,
paddle::optional<const MetaTensor&> bias,
int num_classes,
bool remote_prefetch,
int trainer_id,
const std::vector<int64_t>& height_sections,
const std::vector<std::string>& epmap,
const std::vector<std::string>& table_names,
bool is_sparse,
MetaTensor* out,
MetaTensor* pre_out,
MetaTensor* w_out);
void MultiDotInferMeta(const std::vector<MetaTensor*>& x, MetaTensor* out);
void PsroiPoolInferMeta(const MetaTensor& x,
......
@@ -27,12 +27,15 @@ kernel_library(full_kernel DEPS ${COMMON_KERNEL_DEPS} empty_kernel)
# Some kernels depend on some targets that are not commonly used.
# These targets are not suitable for common dependencies.
# In this case, you need to manually generate them here.
set(MANUAL_BUILD_KERNELS eigh_kernel gumbel_softmax_kernel gumbel_softmax_grad_kernel
hierarchical_sigmoid_kernel hierarchical_sigmoid_grad_kernel
matrix_power_kernel matrix_power_grad_kernel maxout_kernel maxout_grad_kernel pool_kernel
put_along_axis_kernel put_along_axis_grad_kernel segment_pool_kernel segment_pool_grad_kernel
softmax_kernel softmax_grad_kernel take_along_axis_kernel take_along_axis_grad_kernel
triangular_solve_grad_kernel determinant_grad_kernel reduce_kernel)
kernel_library(eigh_kernel DEPS ${COMMON_KERNEL_DEPS} lapack_function)
kernel_library(hierarchical_sigmoid_kernel DEPS ${COMMON_KERNEL_DEPS} matrix_bit_code)
kernel_library(hierarchical_sigmoid_grad_kernel DEPS ${COMMON_KERNEL_DEPS} matrix_bit_code)
kernel_library(gumbel_softmax_kernel DEPS ${COMMON_KERNEL_DEPS} softmax)
kernel_library(gumbel_softmax_grad_kernel DEPS ${COMMON_KERNEL_DEPS} softmax)
kernel_library(reduce_kernel DEPS ${COMMON_KERNEL_DEPS} cast_kernel)
......
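
The two new kernel_library entries are listed under MANUAL_BUILD_KERNELS because they need matrix_bit_code as an extra dependency beyond COMMON_KERNEL_DEPS: both phi kernels below reuse paddle::operators::math::MatrixBitCodeFunctor for the binary-code (path) operations.
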
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/operators/math/matrix_bit_code.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/selected_rows.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/eigen/eigen_function.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace phi {
namespace math = paddle::operators::math;
template <typename T, typename Context>
void HierarchicalSigmoidGradKernelImpl(
const Context& ctx,
const DenseTensor& x,
const DenseTensor& w,
const DenseTensor& label,
const DenseTensor& pre_out,
const DenseTensor& out_grad,
paddle::optional<const DenseTensor&> path,
paddle::optional<const DenseTensor&> code,
paddle::optional<const DenseTensor&> bias,
int num_classes,
bool remote_prefetch,
int trainer_id,
const std::vector<int64_t>& height_sections,
const std::vector<std::string>& epmap,
const std::vector<std::string>& table_names,
bool is_sparse,
DenseTensor* x_grad,
DenseTensor* w_grad,
DenseTensor* bias_grad,
SelectedRows* w_grad_sr = nullptr) {
funcs::SetConstant<Context, T> zero;
DenseTensor pre_out_grad;
pre_out_grad.Resize(pre_out.dims());
ctx.template Alloc<T>(&pre_out_grad);
ctx.template Alloc<T>(x_grad);
zero(ctx, x_grad, static_cast<T>(0.0));
bool is_custom = false;
if (path.get_ptr()) {
is_custom = true;
}
std::unique_ptr<math::MatrixBitCodeFunctor<T>> bit_code;
if (!is_custom) {
bit_code.reset(new math::MatrixBitCodeFunctor<T>(
num_classes, label.template data<int64_t>()));
} else {
bit_code.reset(new math::MatrixBitCodeFunctor<T>(
*(path.get_ptr()), *(code.get_ptr()), label.template data<int64_t>()));
}
// softrelu derivative
auto blas = funcs::GetBlas<Context, T>(ctx);
auto* pre_out_grad_data = pre_out_grad.data<T>();
auto* pre_out_data = pre_out.template data<T>();
auto n = pre_out.numel();
blas.VEXP(n, pre_out_data, pre_out_grad_data);
blas.VINV(n, pre_out_grad_data, pre_out_grad_data);
for (int64_t i = 0; i < n; ++i) {
pre_out_grad_data[i] = 1.0 - pre_out_grad_data[i];
}
bit_code->Sub(&pre_out_grad); // the gradient of clip(w * x + b)
auto* out_grad_data = out_grad.template data<T>();
int64_t dim0 = pre_out_grad.dims()[0];
int64_t dim1 = pre_out_grad.dims()[1];
for (int64_t i = 0; i < dim0; ++i) {
T tmp = out_grad_data[i];
blas.SCAL(dim1, tmp, pre_out_grad_data + i * dim1);
}
// TODO(guosheng): multiply pre_out_grad with subgradient of clipping to
// be consistent with the clipping in forward.
if (bias_grad) {
ctx.template Alloc<T>(bias_grad);
zero(ctx, bias_grad, static_cast<T>(0.0));
bit_code->AddGrad(pre_out_grad, bias_grad);
}
ctx.template Alloc<T>(w_grad);
zero(ctx, w_grad, static_cast<T>(0.0));
if (!is_sparse) {
bit_code->MulGradWeight(pre_out_grad, w_grad, x);
} else {
bit_code->MulGradWeight(pre_out_grad, w_grad_sr, x);
}
bit_code->MulGradError(pre_out_grad, w, x_grad);
}
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/hierarchical_sigmoid_grad_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/cpu/hierarchical_sigmoid_grad.h"
namespace phi {
template <typename T, typename Context>
void HierarchicalSigmoidGradKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& w,
const DenseTensor& label,
const DenseTensor& pre_out,
const DenseTensor& out_grad,
paddle::optional<const DenseTensor&> path,
paddle::optional<const DenseTensor&> code,
paddle::optional<const DenseTensor&> bias,
int num_classes,
bool remote_prefetch,
int trainer_id,
const std::vector<int64_t>& height_sections,
const std::vector<std::string>& epmap,
const std::vector<std::string>& table_names,
bool is_sparse,
DenseTensor* x_grad,
DenseTensor* w_grad,
DenseTensor* bias_grad) {
HierarchicalSigmoidGradKernelImpl<T>(ctx,
x,
w,
label,
pre_out,
out_grad,
path,
code,
bias,
num_classes,
remote_prefetch,
trainer_id,
height_sections,
epmap,
table_names,
is_sparse,
x_grad,
w_grad,
bias_grad);
}
} // namespace phi
PD_REGISTER_KERNEL(hierarchical_sigmoid_grad,
CPU,
ALL_LAYOUT,
phi::HierarchicalSigmoidGradKernel,
float,
double) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/hierarchical_sigmoid_kernel.h"
#include "paddle/fluid/operators/clip_op.h"
#include "paddle/fluid/operators/math/matrix_bit_code.h"
#include "paddle/fluid/platform/transform.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/eigen/eigen_function.h"
#include "paddle/phi/kernels/funcs/math_function_impl.h"
namespace phi {
namespace math = paddle::operators::math;
template <typename T, typename Context>
void HierarchicalSigmoidKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& w,
const DenseTensor& label,
paddle::optional<const DenseTensor&> path,
paddle::optional<const DenseTensor&> code,
paddle::optional<const DenseTensor&> bias,
int num_classes,
bool remote_prefetch,
int trainer_id,
const std::vector<int64_t>& height_sections,
const std::vector<std::string>& epmap,
const std::vector<std::string>& table_names,
bool is_sparse,
DenseTensor* out,
DenseTensor* pre_out,
DenseTensor* w_out) {
size_t num_classes_st = static_cast<size_t>(num_classes);
// for remote prefetch
bool is_custom = false;
if (path.get_ptr()) {
is_custom = true;
}
int64_t code_length = path.get_ptr() ? path.get_ptr()->dims()[1]
: math::FindLastSet(num_classes_st - 1);
int64_t batch_size = x.dims()[0];
DenseTensor sum;
pre_out->Resize(phi::make_ddim({batch_size, code_length}));
ctx.template Alloc<T>(pre_out);
auto* pre_out_data = pre_out->data<T>();
auto pre_out_mat = EigenMatrix<T>::From(*pre_out);
// Not all class(leaf) nodes' path lengths equal code_length, thus init as
// 0s can avoid out of path's loss.
funcs::SetConstant<Context, T> zero;
zero(ctx, pre_out, static_cast<T>(0.0));
auto& place = *ctx.eigen_device();
funcs::RowwiseSum<Context, T> row_sum;
std::unique_ptr<math::MatrixBitCodeFunctor<T>> bit_code;
if (!is_custom) {
bit_code.reset(new math::MatrixBitCodeFunctor<T>(
num_classes_st, label.template data<int64_t>()));
} else {
bit_code.reset(new math::MatrixBitCodeFunctor<T>(
*(path.get_ptr()), *(code.get_ptr()), label.template data<int64_t>()));
}
std::vector<int64_t> sum_dims({batch_size, 1UL});
sum.Resize(phi::make_ddim(sum_dims));
ctx.template Alloc<T>(&sum);
auto sum_mat = EigenMatrix<T>::From(sum);
ctx.template Alloc<T>(out);
auto out_mat = EigenMatrix<T>::From(*out);
if (bias.get_ptr()) {
bit_code->Add(*(bias.get_ptr()), pre_out);
}
bit_code->Mul(pre_out, w, x);
// clip to [-40, 40]
paddle::platform::Transform<Context> trans;
trans(ctx,
pre_out_data,
pre_out_data + pre_out->numel(),
pre_out_data,
paddle::operators::ClipFunctor<T>(static_cast<T>(-40.0),
static_cast<T>(40.0)));
bit_code->Sum(*pre_out, out, static_cast<T>(-1));
// use softrelu to calculate cross entropy
pre_out_mat.device(place) = (static_cast<T>(1.0) + pre_out_mat.exp()).log();
row_sum(ctx, *pre_out, &sum);
// TODO(guosheng): Subtract the out of path's loss, since not all
// class(leaf) nodes' path lengths equal code_length. But it won't break the
// gradient check since both have the out of path's loss and will cancel out
// each other.
out_mat.device(place) = sum_mat + out_mat;
}
} // namespace phi
PD_REGISTER_KERNEL(hierarchical_sigmoid,
CPU,
ALL_LAYOUT,
phi::HierarchicalSigmoidKernel,
float,
double) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void HierarchicalSigmoidGradKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& w,
const DenseTensor& label,
const DenseTensor& pre_out,
const DenseTensor& out_grad,
paddle::optional<const DenseTensor&> path,
paddle::optional<const DenseTensor&> code,
paddle::optional<const DenseTensor&> bias,
int num_classes,
bool remote_prefetch,
int trainer_id,
const std::vector<int64_t>& height_sections,
const std::vector<std::string>& epmap,
const std::vector<std::string>& table_names,
bool is_sparse,
DenseTensor* x_grad,
DenseTensor* w_grad,
DenseTensor* bias_grad);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void HierarchicalSigmoidKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& w,
const DenseTensor& label,
paddle::optional<const DenseTensor&> path,
paddle::optional<const DenseTensor&> code,
paddle::optional<const DenseTensor&> bias,
int num_classes,
bool remote_prefetch,
int trainer_id,
const std::vector<int64_t>& height_sections,
const std::vector<std::string>& epmap,
const std::vector<std::string>& table_names,
bool is_sparse,
DenseTensor* out,
DenseTensor* pre_out,
DenseTensor* w_out);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/selected_rows/hierarchical_sigmoid_grad_kernel.h"
#include "paddle/fluid/framework/mixed_vector.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/cpu/hierarchical_sigmoid_grad.h"
namespace phi {
namespace sr {
static std::vector<int64_t> PathToRows(const DenseTensor& path) {
std::set<int64_t> rows;
const int64_t* paths = path.data<int64_t>();
for (int64_t i = 0; i < path.numel(); ++i) {
int64_t row = paths[i];
if (row < 0) {
continue;
}
rows.emplace(row);
}
return std::vector<int64_t>(rows.begin(), rows.end());
}
template <typename T, typename Context>
void HierarchicalSigmoidGradKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& w,
const DenseTensor& label,
const DenseTensor& pre_out,
const DenseTensor& out_grad,
paddle::optional<const DenseTensor&> path,
paddle::optional<const DenseTensor&> code,
paddle::optional<const DenseTensor&> bias,
int num_classes,
bool remote_prefetch,
int trainer_id,
const std::vector<int64_t>& height_sections,
const std::vector<std::string>& epmap,
const std::vector<std::string>& table_names,
bool is_sparse,
DenseTensor* x_grad,
SelectedRows* w_grad,
DenseTensor* bias_grad) {
PADDLE_ENFORCE_NOT_NULL(
path.get_ptr(),
errors::NotFound("Custom tree must be set for sparse mode!"));
paddle::framework::Vector<int64_t> real_rows = PathToRows(*path);
w_grad->set_rows(real_rows);
// Build a map of id -> row_index to speed up finding the index of one id
w_grad->set_height(w.dims()[0]);
auto* w_grad_value = w_grad->mutable_value();
phi::DDim temp_dim(w.dims());
temp_dim[0] = real_rows.size();
w_grad_value->Resize(temp_dim);
phi::HierarchicalSigmoidGradKernelImpl<T>(ctx,
x,
w,
label,
pre_out,
out_grad,
path,
code,
bias,
num_classes,
remote_prefetch,
trainer_id,
height_sections,
epmap,
table_names,
is_sparse,
x_grad,
w_grad_value,
bias_grad,
w_grad);
}
} // namespace sr
} // namespace phi
PD_REGISTER_KERNEL(hierarchical_sigmoid_grad_sr,
CPU,
ALL_LAYOUT,
phi::sr::HierarchicalSigmoidGradKernel,
float,
double) {}
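
This file is the "unify the impl of dt and sr" step from the commit message: the dense-tensor and selected-rows grad kernels share HierarchicalSigmoidGradKernelImpl. The selected-rows wrapper resizes w_grad's value tensor to the rows actually present in the path table, passes that value tensor as the dense w_grad argument, and hands the SelectedRows itself through the trailing w_grad_sr parameter, so the shared impl only has to pick the matching MulGradWeight overload when is_sparse is set.
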
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/selected_rows.h"
namespace phi {
namespace sr {
template <typename T, typename Context>
void HierarchicalSigmoidGradKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& w,
const DenseTensor& label,
const DenseTensor& pre_out,
const DenseTensor& out_grad,
paddle::optional<const DenseTensor&> path,
paddle::optional<const DenseTensor&> code,
paddle::optional<const DenseTensor&> bias,
int num_classes,
bool remote_prefetch,
int trainer_id,
const std::vector<int64_t>& height_sections,
const std::vector<std::string>& epmap,
const std::vector<std::string>& table_names,
bool is_sparse,
DenseTensor* x_grad,
SelectedRows* w_grad,
DenseTensor* bias_grad);
} // namespace sr
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature HierarchicalSigmoidOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("hierarchical_sigmoid",
{"X", "W", "Label", "PathTable", "PathCode", "Bias"},
{"num_classes",
"remote_prefetch",
"trainer_id",
"height_sections",
"epmap",
"table_names",
"is_sparse"},
{"Out", "PreOut", "W_Out"});
}
KernelSignature HierarchicalSigmoidGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
if (ctx.IsDenseTensorOutput(GradVarName("W"))) {
return KernelSignature(
"hierarchical_sigmoid_grad",
{"X",
"W",
"Label",
"PreOut",
GradVarName("Out"),
"PathTable",
"PathCode",
"Bias"},
{"num_classes",
"remote_prefetch",
"trainer_id",
"height_sections",
"epmap",
"table_names",
"is_sparse"},
{GradVarName("X"), GradVarName("W"), GradVarName("Bias")});
} else if (ctx.IsSelectedRowsOutput(GradVarName("W"))) {
return KernelSignature(
"hierarchical_sigmoid_grad_sr",
{"X",
"W",
"Label",
"PreOut",
GradVarName("Out"),
"PathTable",
"PathCode",
"Bias"},
{"num_classes",
"remote_prefetch",
"trainer_id",
"height_sections",
"epmap",
"table_names",
"is_sparse"},
{GradVarName("X"), GradVarName("W"), GradVarName("Bias")});
} else {
return KernelSignature("unregistered", {}, {}, {});
}
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(hierarchical_sigmoid,
phi::HierarchicalSigmoidOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(hierarchical_sigmoid_grad,
phi::HierarchicalSigmoidGradOpArgumentMapping);
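
The grad mapping dispatches on the variable type of W@GRAD: a DenseTensor output selects hierarchical_sigmoid_grad, a SelectedRows output (the is_sparse path) selects hierarchical_sigmoid_grad_sr. A reduced sketch of that choice, with the ArgumentMappingContext collapsed to the single query being made (illustrative only):

#include <string>

// Stand-in for the one question the mapping asks of ArgumentMappingContext.
struct WGradKind {
  bool is_dense_tensor;
  bool is_selected_rows;
};

std::string PickHSigmoidGradKernel(const WGradKind& w_grad) {
  if (w_grad.is_dense_tensor) return "hierarchical_sigmoid_grad";
  if (w_grad.is_selected_rows) return "hierarchical_sigmoid_grad_sr";
  return "unregistered";  // mirrors the final fallback above
}
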