Unverified commit 889bdde3, authored by zhouweiwei2014, committed by GitHub

[Phi] migrate exponential kernel to phi (#44376)

* [Phi] migrate exponential kernel to phi

* fix comment

* fix CI
Parent 99bf7007
@@ -12,7 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/exponential_op.h"
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/infermeta/unary.h"
namespace paddle {
namespace operators {
@@ -21,13 +23,6 @@ class ExponentialOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
-  void InferShape(framework::InferShapeContext *ctx) const override {
-    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "ExponentialOp");
-    OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "ExponentialOp");
-    auto dim = ctx->GetInputDim("X");
-    ctx->SetOutputDim("Out", dim);
-  }
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
@@ -51,52 +46,6 @@ exponential distribution.
}
};
class ExponentialOpInferVarType
: public framework::PassInDtypeAndVarTypeToOutput {
protected:
std::unordered_map<std::string, std::string> &GetInputOutputWithSameType()
const override {
static std::unordered_map<std::string, std::string> m{{"X", /*->*/ "Out"}};
return m;
}
};
-template <typename T>
-class ExponentialKernel<phi::CPUContext, T> : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext &ctx) const override {
-    auto *out = ctx.Output<framework::Tensor>("Out");
-    T *out_data = out->mutable_data<T>(ctx.GetPlace());
-    T lambda = static_cast<T>(ctx.Attr<float>("lambda"));
-    int64_t size = out->numel();
-    auto gen = framework::DefaultCPUGenerator();
-    auto engine = gen->GetCPUEngine();
-    std::uniform_real_distribution<T> uniform(0.0, 1.0);
-    phi::funcs::exponential_transform<T> trans(lambda);
-    for (int64_t i = 0; i < size; ++i) {
-      out_data[i] = trans(uniform(*engine));
-    }
-  }
-};
-class ExponentialGradOp : public framework::OperatorWithKernel {
- public:
-  using framework::OperatorWithKernel::OperatorWithKernel;
-  void InferShape(framework::InferShapeContext *ctx) const override {
-    OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")),
-                   "Input",
-                   "Out_Grad",
-                   "ExponentialGradOp");
-    auto dout_dim = ctx->GetInputDim(framework::GradVarName("Out"));
-    ctx->SetOutputDim(framework::GradVarName("X"), dout_dim);
-  }
-};
template <typename T>
class ExponentialGradOpMaker : public framework::SingleGradOpMaker<T> {
public:
@@ -104,10 +53,10 @@ class ExponentialGradOpMaker : public framework::SingleGradOpMaker<T> {
protected:
void Apply(GradOpPtr<T> retv) const override {
retv->SetType("exponential_grad");
retv->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
retv->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
retv->SetAttrMap(this->Attrs());
retv->SetType("fill_any_like");
retv->SetInput("X", this->OutputGrad("Out"));
retv->SetAttr("value", 0.0f);
retv->SetOutput("Out", this->InputGrad("X"));
}
};
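Note: the grad maker above no longer emits a dedicated exponential_grad op. Because exponential_ overwrites its output with freshly drawn samples, the result does not depend on the input's values, so the backward reduces to a fill_any_like that writes the constant 0.0f into grad(X).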
@@ -118,24 +67,15 @@ namespace ops = paddle::operators;
namespace plat = paddle::platform;
DECLARE_INPLACE_OP_INFERER(ExponentialInferer, {"X", "Out"});
-DECLARE_INPLACE_OP_INFERER(ExponentialGradInferer,
-                           {paddle::framework::GradVarName("Out"),
-                            paddle::framework::GradVarName("X")});
+DECLARE_INFER_SHAPE_FUNCTOR(exponential,
+                            ExponentialInfershapeFunctor,
+                            PD_INFER_META(phi::UnchangedInferMeta));
REGISTER_OPERATOR(exponential,
ops::ExponentialOp,
ops::ExponentialOpMaker,
ops::ExponentialOpInferVarType,
ops::ExponentialGradOpMaker<paddle::framework::OpDesc>,
ops::ExponentialGradOpMaker<paddle::imperative::OpBase>,
-                  ExponentialInferer);
-REGISTER_OPERATOR(exponential_grad,
-                  ops::ExponentialGradOp,
-                  ExponentialGradInferer);
-REGISTER_OP_CPU_KERNEL(exponential,
-                       ops::ExponentialKernel<phi::CPUContext, float>,
-                       ops::ExponentialKernel<phi::CPUContext, double>);
-REGISTER_OP_CPU_KERNEL(exponential_grad,
-                       ops::ExponentialGradKernel<phi::CPUContext, float>,
-                       ops::ExponentialGradKernel<phi::CPUContext, double>);
+                  ExponentialInferer,
+                  ExponentialInfershapeFunctor);
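With the hand-written InferShape gone, shape inference is now routed through phi: DECLARE_INFER_SHAPE_FUNCTOR binds the exponential op to phi::UnchangedInferMeta, which propagates the input's dims and dtype to the output — the same effect as the removed SetOutputDim call.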
@@ -46,7 +46,8 @@ PD_REGISTER_INFER_META_FN({api.kernel['func'][0]}, phi::{api.infer_meta['func']}
'const paddle::optional<Tensor>&': 'const MetaTensor&'
}
-    wrapped_infermeta_name = get_wrapped_infermeta_name(api.api)
+    wrapped_infermeta_name = get_wrapped_infermeta_name(
+        api.kernel['func'][0])
args = []
for input_name in api.inputs['names']:
if input_name in kernel_params:
......
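This codegen tweak matters precisely for ops like this one: the new API name (exponential_) no longer matches the kernel name (exponential), so the wrapped-infermeta symbol has to be derived from api.kernel['func'][0] rather than from api.api.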
@@ -689,6 +689,17 @@
func : expm1
backward : expm1_grad
+- api : exponential_
+  args : (Tensor x, float lambda)
+  output : Tensor(out)
+  infer_meta :
+    func : UnchangedInferMeta
+    param : [x]
+  kernel :
+    func : exponential
+  inplace : (x -> out)
+  backward : exponential__grad
- api : eye
args : (int64_t num_rows, int64_t num_columns, DataType dtype=DataType::FLOAT32, Place place={})
output : Tensor(out)
......
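For reference, a minimal dygraph usage sketch of the inplace API this yaml entry generates (the seed only makes the run reproducible; values are illustrative):

```python
import paddle

paddle.seed(100)
x = paddle.full([3, 3], -1.0, dtype='float32')
x.exponential_(lam=0.5)  # in-place: x is overwritten with Exp(0.5) samples
print(bool((x.numpy() >= 0).all()))  # True: exponential samples are non-negative
```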
@@ -720,6 +720,15 @@
func : expm1_grad
inplace : (out_grad -> x_grad)
+- backward_api : exponential__grad
+  forward : exponential_ (Tensor x, float lambda) -> Tensor(out)
+  args : (Tensor out_grad)
+  output : Tensor(x_grad)
+  infer_meta :
+    func : UnchangedInferMeta
+  invoke : zeros_like(out_grad, DataType::UNDEFINED, {})
+  inplace : (out_grad -> x_grad)
- backward_api : flatten_grad
forward : flatten(Tensor x, int start_axis, int stop_axis) -> Tensor(out), Tensor(xshape)
args : (Tensor xshape, Tensor out_grad)
......
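The backward is intentionally trivial: since the forward fills out with samples that are independent of x, d(out)/d(x) is identically zero, hence the invoke of zeros_like(out_grad). The dygraph test below verifies exactly this, asserting that x.grad comes back as all zeros.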
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/exponential_kernel.h"
#include <random>
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/generator.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/distribution_helper.h"
namespace phi {

template <typename T, typename Context>
void ExponentialKernel(const Context& dev_ctx,
                       const DenseTensor& x,
                       float lambda,
                       DenseTensor* out) {
  T* out_data = dev_ctx.template Alloc<T>(out);
  auto engine = dev_ctx.GetGenerator()->GetCPUEngine();

  std::uniform_real_distribution<T> uniform(0.0, 1.0);
  phi::funcs::exponential_transform<T> trans(lambda);
  for (int64_t i = 0; i < out->numel(); ++i) {
    out_data[i] = trans(uniform(*engine));
  }
}

}  // namespace phi

PD_REGISTER_KERNEL(
    exponential, CPU, ALL_LAYOUT, phi::ExponentialKernel, float, double) {}
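The CPU kernel above is plain inverse-CDF sampling: draw u from Uniform(0,1), then map it through phi::funcs::exponential_transform. Assuming the transform implements the usual inverse map (the exact expression lives in distribution_helper.h), the math is:

```latex
% If U ~ Uniform(0,1), the inverse-CDF map X = -ln(1-U)/lambda yields
%   P(X <= x) = P(U <= 1 - e^{-lambda x}) = 1 - e^{-lambda x},  x >= 0,
% i.e. X ~ Exp(lambda). Since U and 1-U are equal in distribution,
% X = -ln(U)/lambda is an equivalent implementation.
\[
  U \sim \mathrm{Uniform}(0,1), \qquad
  X = -\frac{\ln(1-U)}{\lambda}
  \;\Longrightarrow\;
  F_X(x) = 1 - e^{-\lambda x}, \quad x \ge 0.
\]
```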
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void ExponentialKernel(const Context &dev_ctx,
                       const DenseTensor &x,
                       float lambda,
                       DenseTensor *out);
} // namespace phi
-// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -12,31 +12,25 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#pragma once
+#include "paddle/phi/kernels/exponential_kernel.h"
-#include "paddle/fluid/framework/generator.h"
-#include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/framework/operator.h"
+#include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/distribution_helper.h"
-#include "paddle/phi/kernels/funcs/math_function.h"
-namespace paddle {
-namespace operators {
+namespace phi {
-template <typename DeviceContext, typename T>
-class ExponentialKernel;
+template <typename T, typename Context>
+void ExponentialKernel(const Context &dev_ctx,
+                       const DenseTensor &x,
+                       float lambda,
+                       DenseTensor *out) {
+  phi::funcs::uniform_distribution<T> dist;
+  phi::funcs::exponential_transform<T> trans(lambda);
+  phi::funcs::distribution_and_transform<T>(dev_ctx, out, dist, trans);
+}
-template <typename DeviceContext, typename T>
-class ExponentialGradKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext& ctx) const override {
-    auto* dx = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
-    dx->mutable_data<T>(ctx.GetPlace());
-    phi::funcs::SetConstant<DeviceContext, T> functor;
-    auto& dev_ctx = ctx.template device_context<DeviceContext>();
-    functor(dev_ctx, dx, static_cast<T>(0));
-  }
-};
+}  // namespace phi
-}  // namespace operators
-}  // namespace paddle
+PD_REGISTER_KERNEL(
+    exponential, GPU, ALL_LAYOUT, phi::ExponentialKernel, float, double) {}
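Note the asymmetry with the CPU path: instead of a per-element host loop, the GPU kernel delegates to phi::funcs::distribution_and_transform, which couples a uniform_distribution sampler with the same exponential_transform so that sampling and transformation both run on the device.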
-/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -12,37 +12,15 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/exponential_op.h"
namespace paddle {
namespace operators {
template <typename T>
class ExponentialKernel<platform::CUDADeviceContext, T>
: public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
framework::Tensor* out = ctx.Output<framework::Tensor>("Out");
auto& dev_cxt = ctx.template device_context<platform::CUDADeviceContext>();
T lambda = static_cast<T>(ctx.Attr<float>("lambda"));
phi::funcs::uniform_distribution<T> dist;
phi::funcs::exponential_transform<T> trans(lambda);
phi::funcs::distribution_and_transform<T>(dev_cxt, out, dist, trans);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(
exponential,
ops::ExponentialKernel<plat::CUDADeviceContext, float>,
ops::ExponentialKernel<plat::CUDADeviceContext, double>);
REGISTER_OP_CUDA_KERNEL(
exponential_grad,
ops::ExponentialGradKernel<plat::CUDADeviceContext, float>,
ops::ExponentialGradKernel<plat::CUDADeviceContext, double>);
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature ExponentialOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("exponential", {"X"}, {"lambda"}, {"Out"});
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(exponential, phi::ExponentialOpArgumentMapping);
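This new signature file replaces the deleted CUDA registration glue: it tells the compatibility layer how to translate the legacy fluid op (input X, attribute lambda, output Out) into a call to the phi kernel registered as exponential, so static-graph programs built on the old op keep working.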
@@ -18,13 +18,13 @@ import numpy as np
from op_test import OpTest
import os
paddle.enable_static()
paddle.seed(100)
class TestExponentialOp1(OpTest):
def setUp(self):
paddle.enable_static()
self.op_type = "exponential"
self.config()
@@ -87,8 +87,14 @@ class TestExponentialAPI(unittest.TestCase):
def test_dygraph(self):
paddle.disable_static()
x = paddle.full([10, 10], -1., dtype='float32')
-        x.exponential_(0.5)
-        self.assertTrue(np.min(x.numpy()) >= 0)
+        x.stop_gradient = False
+        y = 2 * x
+        y.exponential_(0.5)
+        print(y)
+        self.assertTrue(np.min(y.numpy()) >= 0)
+        y.backward()
+        self.assertTrue(np.array_equal(x.grad.numpy(), np.zeros([10, 10])))
paddle.enable_static()
def test_fixed_random_number(self):
......
@@ -1052,7 +1052,9 @@ def exponential_(x, lam=1.0, name=None):
# [0.72520673, 0.45208144, 0.30234432]]
"""
-    if paddle.in_dynamic_mode():
+    if in_dygraph_mode():
+        return _C_ops.final_state_exponential_(x, lam)
+    elif paddle.in_dynamic_mode():
        return _C_ops.exponential_(x, "lambda", lam)
check_variable_and_dtype(x, "x", ["float32", "float64"], "exponential")
......
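The branch order here is the migration in miniature: in_dygraph_mode() (the new final-state eager mode) calls _C_ops.final_state_exponential_, which is generated from the yaml entry above and backed by the phi kernel, while the paddle.in_dynamic_mode() branch keeps the legacy dygraph _C_ops call; static graph still falls through to the existing check_variable_and_dtype/append_op path below.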