未验证 提交 6982871d 编写于 作者: S seemingwang 提交者: GitHub

renorm op (#38130)

* graph engine demo

* upload unsaved changes

* fix dependency error

* fix shard_num problem

* py client

* remove lock and graph-type

* add load direct graph

* add load direct graph

* add load direct graph

* batch random_sample

* batch_sample_k

* fix num_nodes size

* batch brpc

* batch brpc

* add test

* add test

* add load_nodes; change add_node function

* change sample return type to pair

* resolve conflict

* resolved conflict

* resolved conflict

* separate server and client

* merge pair type

* fix

* resolved conflict

* fixed segment fault; high-level VLOG for load edges and load nodes

* random_sample return 0

* rm useless loop

* test:load edge

* fix ret -1

* test: rm sample

* rm sample

* random_sample return future

* random_sample return int

* test fake node

* fixed here

* memory leak

* remove test code

* fix return problem

* add common_graph_table

* random sample node &test & change data-structure from linkedList to vector

* add common_graph_table

* sample with srand

* add node_types

* optimize nodes sample

* recover test

* random sample

* destruct weighted sampler

* GraphEdgeBlob

* WeightedGraphEdgeBlob to GraphEdgeBlob

* WeightedGraphEdgeBlob to GraphEdgeBlob

* pybind sample nodes api

* pull nodes with step

* fixed pull_graph_list bug; add test for pull_graph_list by step

* add graph table;name

* add graph table;name

* add pybind

* add pybind

* add FeatureNode

* add FeatureNode

* add FeatureNode Serialize

* add FeatureNode Serialize

* get_feat_node

* avoid local rpc

* fix get_node_feat

* fix get_node_feat

* remove log

* get_node_feat return  py:bytes

* merge develop with graph_engine

* fix threadpool.h head

* fix

* fix typo

* resolve conflict

* fix conflict

* recover lost content

* fix pybind of FeatureNode

* recover cmake

* recover tools

* resolve conflict

* resolve linking problem

* code style

* change test_server port

* fix code problems

* remove shard_num config

* remove redundent threads

* optimize start server

* remove logs

* fix code problems by reviewers' suggestions

* move graph files into a folder

* code style change

* remove graph operations from base table

* optimize get_feat function of graph engine

* fix long long count problem

* remove redandunt graph files

* remove unused shell

* recover dropout_op_pass.h

* fix potential stack overflow when request number is too large & node add & node clear & node remove

* when sample k is larger than neigbor num, return directly

* using random seed generator of paddle to speed up

* fix bug of random sample k

* fix code style

* fix code style

* add remove graph to fleet_py.cc

* fix blocking_queue problem

* fix style

* fix

* recover capacity check

* add remove graph node; add set_feature

* add remove graph node; add set_feature

* add remove graph node; add set_feature

* add remove graph node; add set_feature

* fix distributed op combining problems

* optimize

* remove logs

* fix MultiSlotDataGenerator error

* cache for graph engine

* fix type compare error

* more test&fix thread terminating problem

* remove header

* change time interval of shrink

* use cache when sample nodes

* remove unused function

* change unique_ptr to shared_ptr

* simplify cache template

* cache api on client

* fix

* reduce sample threads when cache is not used

* reduce cache memory

* cache optimization

* remove test function

* remove extra fetch function

* graph-engine data transfer optimization

* support graph_split load&query

* remove logs

* change shards to pointer vector

* use inference

* remove test code

* renorm op

* simplify renorm op

* recover local changes

* recover renorm op kernel

* fix init

* add blanklines in renorm doc

* fix import

* fix import
Co-authored-by: NHuang Zhengjie <270018958@qq.com>
Co-authored-by: NWeiyue Su <weiyue.su@gmail.com>
Co-authored-by: Nsuweiyue <suweiyue@baidu.com>
Co-authored-by: Nluobin06 <luobin06@baidu.com>
Co-authored-by: Nliweibin02 <liweibin02@baidu.com>
Co-authored-by: Ntangwei12 <tangwei12@baidu.com>
上级 ee69f437
......@@ -272,4 +272,4 @@ void RunGraphSplit() {
worker_ptr_->finalize_worker();
}
TEST(RunGraphSplit, Run) { RunGraphSplit(); }
\ No newline at end of file
TEST(RunGraphSplit, Run) { RunGraphSplit(); }
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/renorm_op.h"
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h"
#endif
namespace paddle {
namespace operators {
class RenormOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
using DDim = paddle::framework::DDim;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "abs");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "abs");
auto in_dims = ctx->GetInputDim("X");
ctx->SetOutputDim("Out", in_dims);
ctx->ShareLoD("X", /*->*/ "Out");
}
};
class RenormOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "(Tensor), The input tensor of renorm op.");
AddOutput("Out", "(Tensor), The output tensor of renorm op.");
AddAttr<float>("p", "(float, norm's power");
AddAttr<int>("axis",
"int,the dimension to slice over to get the sub-tensors");
AddAttr<float>("max_norm", "(float, the norm upper-bound");
AddAttr<bool>("use_cudnn",
"(bool, default false) Only used in cudnn kernel, need "
"install cudnn")
.SetDefault(false);
AddAttr<bool>("use_mkldnn",
"(bool, default false) Only used in mkldnn kernel")
.SetDefault(false);
AddComment(R"DOC(
Renorm Operator.
This operator is used to scale tensor sliced by axis if its p-norm execeeds maxnorm
)DOC");
}
};
class RenormGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input",
"Out@Grad", "AbsGrad");
OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("X")), "Output",
"X@Grad", "AbsGrad");
auto dout_dims = ctx->GetInputDim(framework::GradVarName("Out"));
ctx->SetOutputDim(framework::GradVarName("X"), dout_dims);
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
auto dtype = OperatorWithKernel::IndicateVarDataType(ctx, "X");
return framework::OpKernelType(dtype, ctx.GetPlace());
}
};
template <typename T>
class RenormGradMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
void Apply(GradOpPtr<T> retv) const override {
retv->SetType("renorm_grad");
retv->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
retv->SetInput("X", this->Input("X"));
retv->SetAttrMap(this->Attrs());
retv->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(renorm, ops::RenormOp, ops::RenormOpMaker,
ops::RenormGradMaker<paddle::framework::OpDesc>,
ops::RenormGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(renorm_grad, ops::RenormGradOp);
REGISTER_OP_CPU_KERNEL(renorm, ops::CPURenormKernel<float>,
ops::CPURenormKernel<double>);
REGISTER_OP_CPU_KERNEL(renorm_grad, ops::CPURenormGradKernel<float>,
ops::CPURenormGradKernel<double>);
\ No newline at end of file
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <algorithm>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h"
#include "paddle/fluid/operators/reduce_ops/reduce_functor_op.h"
#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
#include "paddle/fluid/operators/reduce_ops/reduce_op.h"
#include "paddle/fluid/operators/renorm_op.h"
#include "paddle/fluid/operators/utils.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "stdio.h"
namespace paddle {
namespace operators {
__device__ __forceinline__ float inline_pow(float base, float exponent) {
return pow(base, exponent);
}
__device__ __forceinline__ double inline_pow(double base, double exponent) {
return pow(base, exponent);
}
__device__ __forceinline__ float inline_abs(float x) { return abs(x); }
__device__ __forceinline__ double inline_abs(double x) { return abs(x); }
template <typename Tx, typename Ty = Tx>
struct UnsignedPowFunctor {
HOSTDEVICE explicit inline UnsignedPowFunctor(float porder) {
this->porder = porder;
}
HOSTDEVICE inline Ty operator()(const Tx& x) const {
return static_cast<Ty>(inline_pow(inline_abs(x), static_cast<Tx>(porder)));
}
float porder;
};
template <typename T>
__global__ void RenormKernelFunc3(int64_t size, T* dim_value, float p,
float max_norm) {
int64_t i = ((int64_t)blockIdx.x) * blockDim.x + threadIdx.x;
if (i < size) {
T temp = pow(dim_value[i], (T)(1.0 / p));
dim_value[i] = 1.0;
if (temp > max_norm) dim_value[i] = max_norm / temp;
}
}
template <typename T>
__global__ void RenormKernelFunc4(T* x_data, T* out_data, int64_t size,
T* dim_value, int64_t dimension_each,
int64_t dim_divisor) {
int64_t i = ((int64_t)blockIdx.x) * blockDim.x + threadIdx.x;
auto dim_index = i / dim_divisor % dimension_each;
if (i < size) {
if (dim_value[dim_index] < 1.0)
out_data[i] = dim_value[dim_index] * x_data[i];
else
out_data[i] = x_data[i];
}
}
template <typename T>
__global__ void RenormGradKernelFunc1(T* x_data, T* dout_data, T* pow_value,
T* mul_value, int64_t size,
int64_t dimension_each, float p,
int64_t dim_divisor) {
int64_t i = ((int64_t)blockIdx.x) * blockDim.x + threadIdx.x;
auto dim_index = i / dim_divisor % dimension_each;
if (i < size) {
pow_value[i] = pow(abs(x_data[i]), (T)p);
mul_value[i] = x_data[i] * dout_data[i];
}
}
template <typename T>
__global__ void RenormGradKernelFunc2(T* x_data, T* dout_data, T* dx_data,
int64_t size, T* dim_value,
T* dim_power_sum, T* weight_derivative,
int64_t dimension_each, float p,
float max_norm, int64_t dim_divisor) {
int64_t i = ((int64_t)blockIdx.x) * blockDim.x + threadIdx.x;
auto dim_index = i / dim_divisor % dimension_each;
if (i < dimension_each) {
dim_power_sum[i] = 0;
auto temp = pow(dim_value[i], (T)(1.0 / p));
if (temp > max_norm) {
dim_power_sum[i] = pow(dim_value[i], (T)(-1.0 - 1.0 / p)) * -1 * max_norm;
dim_value[i] = max_norm / temp;
} else
dim_value[i] = 1.0;
}
__syncthreads();
if (i < size) {
dx_data[i] = dim_value[dim_index] * dout_data[i];
dx_data[i] = dx_data[i] +
weight_derivative[dim_index] * dim_power_sum[dim_index] *
pow(abs(x_data[i]), T(p - 1.0)) *
(x_data[i] >= 0 ? 1 : -1);
}
}
template <typename T>
class CUDARenormKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
const Tensor* x = context.Input<Tensor>("X");
Tensor* out = context.Output<Tensor>("Out");
auto numel = x->numel();
T* x_data = (T*)x->data<T>();
auto input_dims = x->dims();
float max_norm = context.Attr<float>("max_norm");
float p = context.Attr<float>("p");
int dim = context.Attr<int>("axis");
auto dimension_each = input_dims[dim];
auto dim_size = input_dims.size();
framework::Tensor pow_value, dim_value;
int64_t dim_divisor = 1, pre_mul = 1;
for (int i = dim + 1; i < dim_size; i++) dim_divisor *= input_dims[i];
for (int i = 0; i < dim; i++) pre_mul *= input_dims[i];
pow_value.Resize(
framework::make_ddim({pre_mul, dimension_each, dim_divisor}));
dim_value.Resize(framework::make_ddim({dimension_each}));
pow_value.mutable_data<T>(context.GetPlace());
out->Resize(framework::make_ddim(framework::vectorize(input_dims)));
T* out_data = out->mutable_data<T>(context.GetPlace());
auto stream = context.cuda_device_context().stream();
int block = std::min(numel, static_cast<int64_t>(256));
using MT = typename details::MPTypeTrait<T>::Type;
int grid = (numel + block - 1) / block;
int block2 = std::min(dimension_each, static_cast<int64_t>(256));
int grid2 = (dimension_each + block2 - 1) / block2;
std::vector<const framework::Tensor*> ins = {x};
std::vector<framework::Tensor*> outs = {&pow_value};
auto func = UnsignedPowFunctor<MT, T>(p);
const auto& cuda_ctx =
context.template device_context<platform::CUDADeviceContext>();
LaunchSameDimsElementwiseCudaKernel<ElementwiseType::kUnary, MT, T,
UnsignedPowFunctor<MT, T>>(
cuda_ctx, ins, &outs, func);
std::vector<int> reduce_axis = {0, 2};
TensorReduceFunctorImpl<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
pow_value, &dim_value, kps::IdentityFunctor<T>(), reduce_axis, stream);
RenormKernelFunc3<T><<<grid2, block2, 0, stream>>>(
numel, dim_value.mutable_data<T>(context.GetPlace()), p, max_norm);
RenormKernelFunc4<T><<<grid, block, 0, stream>>>(
x_data, out_data, numel, dim_value.mutable_data<T>(context.GetPlace()),
dimension_each, dim_divisor);
// platform::GpuStreamSync(stream);
}
};
template <typename T>
class CUDAGradRenormKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
const framework::Tensor* d_out =
ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
const framework::Tensor* x = ctx.Input<framework::Tensor>("X");
framework::Tensor* d_x =
ctx.Output<framework::Tensor>(framework::GradVarName("X"));
auto numel = d_out->numel();
T* dout_data = (T*)d_out->data<T>();
T* x_data = (T*)x->data<T>();
auto input_dims = x->dims();
float max_norm = ctx.Attr<float>("max_norm");
float p = ctx.Attr<float>("p");
int dim = ctx.Attr<int>("axis");
auto dimension_each = input_dims[dim];
auto dim_size = input_dims.size();
int64_t dim_divisor = 1, pre_mul = 1;
for (int i = dim + 1; i < dim_size; i++) dim_divisor *= input_dims[i];
for (int i = 0; i < dim; i++) pre_mul *= input_dims[i];
d_x->Resize(framework::make_ddim(framework::vectorize(input_dims)));
T* dx_data = d_x->mutable_data<T>(ctx.GetPlace());
framework::Tensor pow_value, mul_value, dim_value, dim_power_sum,
weight_derivative;
pow_value.Resize(
framework::make_ddim({pre_mul, dimension_each, dim_divisor}));
mul_value.Resize(
framework::make_ddim({pre_mul, dimension_each, dim_divisor}));
dim_value.Resize(framework::make_ddim({dimension_each}));
dim_power_sum.Resize(framework::make_ddim({dimension_each}));
weight_derivative.Resize(framework::make_ddim({dimension_each}));
auto stream = ctx.cuda_device_context().stream();
int block = std::min(numel, static_cast<int64_t>(256));
int grid = (numel + block - 1) / block;
pow_value.mutable_data<T>(ctx.GetPlace());
mul_value.mutable_data<T>(ctx.GetPlace());
dim_value.mutable_data<T>(ctx.GetPlace());
dim_power_sum.mutable_data<T>(ctx.GetPlace());
weight_derivative.mutable_data<T>(ctx.GetPlace());
RenormGradKernelFunc1<T><<<grid, block, 0, stream>>>(
x_data, dout_data, pow_value.mutable_data<T>(ctx.GetPlace()),
mul_value.mutable_data<T>(ctx.GetPlace()), numel, dimension_each, p,
dim_divisor);
std::vector<int> reduce_axis = {0, 2};
TensorReduceFunctorImpl<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
pow_value, &dim_value, kps::IdentityFunctor<T>(), reduce_axis, stream);
TensorReduceFunctorImpl<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
mul_value, &weight_derivative, kps::IdentityFunctor<T>(), reduce_axis,
stream);
RenormGradKernelFunc2<T><<<grid, block, 0, stream>>>(
x_data, dout_data, dx_data, numel,
dim_value.mutable_data<T>(ctx.GetPlace()),
dim_power_sum.mutable_data<T>(ctx.GetPlace()),
weight_derivative.mutable_data<T>(ctx.GetPlace()), dimension_each, p,
max_norm, dim_divisor);
// platform::GpuStreamSync(stream);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(renorm, ops::CUDARenormKernel<float>,
ops::CUDARenormKernel<double>);
REGISTER_OP_CUDA_KERNEL(renorm_grad, ops::CUDAGradRenormKernel<float>,
ops::CUDAGradRenormKernel<double>);
\ No newline at end of file
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "math.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/math/complex_functors.h"
#include "paddle/fluid/platform/for_range.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
// template <typename T>
// struct NormDimValueFunctor<T> {
// NormDimValueFunctor(T* input, T* output, int64_t dim_divisor, int64_t
// dimension_each, float p)
// : input_(input), output_(output),dim_divisor_(dim_divisor),
// dimension_each_(dimension_each),p_(p) {}
// HOSTDEVICE void operator()(int64_t i) const {
// auto dim_index = i / dim_divsor % dimension_each;
// dim_value[dim_index] += std::pow(std::abs(input[i]), p);
// }
// T* input_;
// T* output_;
// int64_t dimension_each_, dim_divisor_;
// float p_,max_norm_;
// };
// template <typename DeviceContext, typename T>
template <typename T>
class CPURenormKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
const Tensor* x = context.Input<Tensor>("X");
Tensor* out = context.Output<Tensor>("Out");
auto numel = x->numel();
auto* x_data = x->data<T>();
auto input_dims = x->dims();
float max_norm = context.Attr<float>("max_norm");
float p = context.Attr<float>("p");
int dim = context.Attr<int>("axis");
auto dimension_each = input_dims[dim];
auto dim_size = input_dims.size();
int64_t dim_divisor = 1;
for (int i = dim + 1; i < dim_size; i++) dim_divisor *= input_dims[i];
// auto& dev_ctx = ctx.template device_context<DeviceContext>();
// std::vector<int64_t> dim_index(dim_size, 0);
std::vector<T> dim_value(dimension_each,
0); // dim_value = (x1^p + x2^p + x3^p....)^(1/p)
auto* out_data =
out->mutable_data<T>(context.GetPlace(), size_t(numel * sizeof(T)));
int64_t index = 0, dim_index = 0;
for (int64_t i = 0; i < numel; i++) {
// auto dim_index = i / dim_divsor % dimension_each;
dim_value[dim_index] += std::pow(std::abs(x_data[i]), p);
index++;
if (index == dim_divisor) {
dim_index++;
if (dim_index == dimension_each) {
dim_index = 0;
}
index = 0;
}
}
for (int64_t i = 0; i < dimension_each; i++) {
dim_value[i] = std::pow(dim_value[i], 1.0 / p);
if (dim_value[i] > max_norm)
dim_value[i] = max_norm / dim_value[i];
else
dim_value[i] = 1.0;
// dim_index[i] = 0;
}
index = dim_index = 0;
for (int64_t i = 0; i < numel; i++) {
// auto dim_index = i / dim_divsor % dimension_each;
out_data[i] = dim_value[dim_index] < 1.0
? dim_value[dim_index] * x_data[i]
: x_data[i];
index++;
if (index == dim_divisor) {
dim_index++;
if (dim_index == dimension_each) {
dim_index = 0;
}
index = 0;
}
}
}
};
// template <typename DeviceContext, typename T>
template <typename T>
class CPURenormGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const {
const framework::Tensor* d_out =
ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
const framework::Tensor* x = ctx.Input<framework::Tensor>("X");
framework::Tensor* d_x =
ctx.Output<framework::Tensor>(framework::GradVarName("X"));
auto numel = d_out->numel();
auto* dout_data = d_out->data<T>();
auto* x_data = x->data<T>();
auto input_dims = x->dims();
float max_norm = ctx.Attr<float>("max_norm");
float p = ctx.Attr<float>("p");
int dim = ctx.Attr<int>("axis");
auto dimension_each = input_dims[dim];
auto dim_size = input_dims.size();
int64_t dim_divisor = 1;
for (int i = dim + 1; i < dim_size; i++) dim_divisor *= input_dims[i];
auto* dx_data = d_x->mutable_data<T>(
ctx.GetPlace(), static_cast<size_t>(numel * sizeof(T)));
std::vector<T> dim_value(dimension_each, 0),
dim_power_sum(dimension_each, 0),
weight_derivative(dimension_each, 0.0);
int64_t index = 0, dim_index = 0;
for (int64_t i = 0; i < numel; i++) {
// auto dim_index = i / dim_divsor % dimension_each;
dim_value[dim_index] += std::pow(std::abs(x_data[i]), p);
index++;
if (index == dim_divisor) {
dim_index++;
if (dim_index == dimension_each) {
dim_index = 0;
}
index = 0;
}
}
for (int64_t i = 0; i < dimension_each; i++) {
auto temp = std::pow(dim_value[i], 1.0 / p);
if (temp > max_norm) {
dim_power_sum[i] =
std::pow(dim_value[i], (T)(-1.0 - 1.0 / p)) * -1 * max_norm;
dim_value[i] = max_norm / temp;
} else
dim_value[i] = 1.0;
}
index = dim_index = 0;
for (int64_t i = 0; i < numel; i++) {
// auto dim_index = i / dim_divsor % dimension_each;
dx_data[i] = dim_value[dim_index] * dout_data[i];
weight_derivative[dim_index] += x_data[i] * dout_data[i];
index++;
if (index == dim_divisor) {
dim_index++;
if (dim_index == dimension_each) {
dim_index = 0;
}
index = 0;
}
}
index = dim_index = 0;
for (int64_t i = 0; i < numel; i++) {
// auto dim_index = i / dim_divsor % dimension_each;
dx_data[i] += weight_derivative[dim_index] * dim_power_sum[dim_index] *
std::pow(std::abs(x_data[i]), p - 1.0) *
(x_data[i] >= 0 ? 1 : -1);
index++;
if (index == dim_divisor) {
dim_index++;
if (dim_index == dimension_each) {
dim_index = 0;
}
index = 0;
}
}
}
};
} // namespace operators
} // namespace paddle
......@@ -211,6 +211,7 @@ from .tensor.math import remainder # noqa: F401
from .tensor.math import mod # noqa: F401
from .tensor.math import floor_mod # noqa: F401
from .tensor.math import multiply # noqa: F401
from .tensor.math import renorm # noqa: F401
from .tensor.math import add # noqa: F401
from .tensor.math import subtract # noqa: F401
from .tensor.math import logsumexp # noqa: F401
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import paddle
import numpy as np
import paddle.fluid.core as core
import paddle.fluid as fluid
from paddle.fluid import Program, program_guard
paddle.set_device('cpu')
class TestRenormAPI(unittest.TestCase):
def input_data(self):
self.data_x = np.array(
[[[2.0, 2, -2], [3, 0.3, 3]], [[2, -8, 2], [3.1, 3.7, 3]]])
self.p = 1.0
self.dim = 2
self.max_norm = 2.05
def test_renorm_api(self):
paddle.enable_static()
self.input_data()
# case 1:
with program_guard(Program(), Program()):
#x = fluid.layers.data(name = 'x',shape=[-1, 2, 3])
x = paddle.static.data(name="x", shape=[-1, 2, 3], dtype='float64')
z = paddle.renorm(x, self.p, self.dim, self.max_norm)
exe = fluid.Executor(fluid.CPUPlace())
res, = exe.run(feed={"x": self.data_x},
fetch_list=[z],
return_numpy=False)
expected = np.array([[[0.40594056, 0.29285714, -0.41000000],
[0.60891086, 0.04392857, 0.61500001]],
[[0.40594056, -1.17142856, 0.41000000],
[0.62920785, 0.54178572, 0.61500001]]])
self.assertTrue(np.allclose(expected, np.array(res)))
def test_dygraph_api(self):
self.input_data()
# case axis none
with fluid.dygraph.guard():
input = [[[2.0, 2, -2], [3, 0.3, 3]], [[2, -8, 2], [3.1, 3.7, 3]]]
x = paddle.to_tensor(input, stop_gradient=False)
y = paddle.renorm(x, 1.0, 2, 2.05)
expected = np.array([[[0.40594056, 0.29285714, -0.41000000],
[0.60891086, 0.04392857, 0.61500001]],
[[0.40594056, -1.17142856, 0.41000000],
[0.62920785, 0.54178572, 0.61500001]]])
self.assertTrue(np.allclose(expected, np.array(y)))
z = paddle.mean(y)
z.backward(retain_graph=True)
expected_grad = np.array(
[[[0, 0.01394558, 0.02733333], [0, 0.01394558, 0.00683333]],
[[0, 0.01045918, 0.00683333], [0, 0.01394558, 0.00683333]]])
self.assertTrue(np.allclose(expected_grad, np.array(x.grad)))
#test exception:
with fluid.dygraph.guard():
input = [[[2.0, 2, -2], [3, 0.3, 3]], [[2, -8, 2], [3.1, 3.7, 3]]]
x = paddle.to_tensor(input, stop_gradient=False)
exp = False
try:
paddle.renorm(x, 1.0, 8, 2.05)
except:
exp = True
self.assertTrue(exp)
exp = False
try:
paddle.renorm(x, 1.0, -4, 2.05)
except:
exp = True
self.assertTrue(exp)
y = paddle.renorm(x, 1.0, -1, 2.05)
expected = np.array([[[0.40594056, 0.29285714, -0.41000000],
[0.60891086, 0.04392857, 0.61500001]],
[[0.40594056, -1.17142856, 0.41000000],
[0.62920785, 0.54178572, 0.61500001]]])
self.assertTrue(np.allclose(expected, np.array(y)))
if __name__ == '__main__':
unittest.main()
......@@ -1194,6 +1194,62 @@ def addmm(input, x, y, beta=1.0, alpha=1.0, name=None):
type="addmm", inputs=inputs, attrs=attrs, outputs={"Out": out})
return out
def renorm(x, p, axis, max_norm):
"""
**renorm**
This operator is used to calculate the p-norm along the axis,
suppose the input-shape on axis dimension has the value of T, then
the tensor is split into T parts, the p-norm should be calculated for each
part, if the p-norm for part i is larger than max-norm, then each element
in part i should be re-normalized at the same scale so that part-i' p-norm equals
max-norm exactly, otherwise part-i stays unchanged.
Args:
x (Tensor): The input Tensor
p (float): The power of the norm operation.
axis (int): the dimension to slice the tensor.
max-norm (float): the maximal norm limit.
Returns:
Tensor: the renorm Tensor.
Examples:
.. code-block:: python
import paddle
input = [[[2.0,2,-2],[3,0.3,3]],[[2,-8,2],[3.1,3.7,3]]]
x = paddle.to_tensor(input,dtype='float32')
y = paddle.renorm(x, 1.0, 2, 2.05)
print(y)
# [[[ 0.40594056, 0.29285714, -0.41000000],
# [ 0.60891086, 0.04392857, 0.61500001]],
# [[ 0.40594056, -1.17142856, 0.41000000],
# [ 0.62920785, 0.54178572, 0.61500001]]])
"""
input_shape = x.shape
check_variable_and_dtype(x, 'x', ['float32', 'float64'], 'renorm')
if not axis < len(input_shape):
raise ValueError("the axis:{} should be less then the shape's size {}:{}".format(axis,len(input_shape),input_shape))
if not axis >=0:
if not axis >= -1 * len(input_shape):
raise ValueError("the axis:{} should not be less than -1 * length of input_shape:{}".format(axis,-1 * len(input_shape)))
axis = axis + len(input_shape)
if in_dygraph_mode():
out = core.ops.renorm(x, 'p',p, 'axis',axis, 'max_norm', max_norm)
return out
inputs = {'X': x}
attrs = {'p': p, 'axis': axis, 'max_norm':max_norm}
helper = LayerHelper("renorm", **locals())
out = helper.create_variable_for_type_inference(dtype=x.dtype)
helper.append_op(
type="renorm", inputs=inputs, attrs=attrs, outputs={"Out": out})
return out
def inner(x, y, name=None):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册