Unverified commit 889bdde3, authored by zhouweiwei2014, committed by GitHub

[Phi] migrate exponential kernel to phi (#44376)

* [Phi] migrate exponential kernel to phi

* fix comment

* fix CI
Parent 99bf7007
@@ -12,7 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#include "paddle/fluid/operators/exponential_op.h"
+#include "paddle/fluid/framework/infershape_utils.h"
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/phi/infermeta/unary.h"
 
 namespace paddle {
 namespace operators {
@@ -21,13 +23,6 @@ class ExponentialOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
 
-  void InferShape(framework::InferShapeContext *ctx) const override {
-    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "ExponentialOp");
-    OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "ExponentialOp");
-    auto dim = ctx->GetInputDim("X");
-    ctx->SetOutputDim("Out", dim);
-  }
-
  protected:
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const override {
@@ -51,52 +46,6 @@ exponential distribution.
   }
 };
 
-class ExponentialOpInferVarType
-    : public framework::PassInDtypeAndVarTypeToOutput {
- protected:
-  std::unordered_map<std::string, std::string> &GetInputOutputWithSameType()
-      const override {
-    static std::unordered_map<std::string, std::string> m{{"X", /*->*/ "Out"}};
-    return m;
-  }
-};
-
-template <typename T>
-class ExponentialKernel<phi::CPUContext, T> : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext &ctx) const override {
-    auto *out = ctx.Output<framework::Tensor>("Out");
-    T *out_data = out->mutable_data<T>(ctx.GetPlace());
-    T lambda = static_cast<T>(ctx.Attr<float>("lambda"));
-    int64_t size = out->numel();
-
-    auto gen = framework::DefaultCPUGenerator();
-    auto engine = gen->GetCPUEngine();
-    std::uniform_real_distribution<T> uniform(0.0, 1.0);
-    phi::funcs::exponential_transform<T> trans(lambda);
-    for (int64_t i = 0; i < size; ++i) {
-      out_data[i] = trans(uniform(*engine));
-    }
-  }
-};
-
-class ExponentialGradOp : public framework::OperatorWithKernel {
- public:
-  using framework::OperatorWithKernel::OperatorWithKernel;
-
-  void InferShape(framework::InferShapeContext *ctx) const override {
-    OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")),
-                   "Input",
-                   "Out_Grad",
-                   "ExponentialGradOp");
-    auto dout_dim = ctx->GetInputDim(framework::GradVarName("Out"));
-    ctx->SetOutputDim(framework::GradVarName("X"), dout_dim);
-  }
-};
-
 template <typename T>
 class ExponentialGradOpMaker : public framework::SingleGradOpMaker<T> {
  public:
@@ -104,10 +53,10 @@ class ExponentialGradOpMaker : public framework::SingleGradOpMaker<T> {
  protected:
   void Apply(GradOpPtr<T> retv) const override {
-    retv->SetType("exponential_grad");
-    retv->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
-    retv->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
-    retv->SetAttrMap(this->Attrs());
+    retv->SetType("fill_any_like");
+    retv->SetInput("X", this->OutputGrad("Out"));
+    retv->SetAttr("value", 0.0f);
+    retv->SetOutput("Out", this->InputGrad("X"));
   }
 };
@@ -118,24 +67,15 @@ namespace ops = paddle::operators;
 namespace plat = paddle::platform;
 
 DECLARE_INPLACE_OP_INFERER(ExponentialInferer, {"X", "Out"});
-DECLARE_INPLACE_OP_INFERER(ExponentialGradInferer,
-                           {paddle::framework::GradVarName("Out"),
-                            paddle::framework::GradVarName("X")});
+
+DECLARE_INFER_SHAPE_FUNCTOR(exponential,
+                            ExponentialInfershapeFunctor,
+                            PD_INFER_META(phi::UnchangedInferMeta));
 
 REGISTER_OPERATOR(exponential,
                   ops::ExponentialOp,
                   ops::ExponentialOpMaker,
-                  ops::ExponentialOpInferVarType,
                   ops::ExponentialGradOpMaker<paddle::framework::OpDesc>,
                   ops::ExponentialGradOpMaker<paddle::imperative::OpBase>,
-                  ExponentialInferer);
-REGISTER_OPERATOR(exponential_grad,
-                  ops::ExponentialGradOp,
-                  ExponentialGradInferer);
-
-REGISTER_OP_CPU_KERNEL(exponential,
-                       ops::ExponentialKernel<phi::CPUContext, float>,
-                       ops::ExponentialKernel<phi::CPUContext, double>);
-REGISTER_OP_CPU_KERNEL(exponential_grad,
-                       ops::ExponentialGradKernel<phi::CPUContext, float>,
-                       ops::ExponentialGradKernel<phi::CPUContext, double>);
+                  ExponentialInferer,
+                  ExponentialInfershapeFunctor);
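Note on the backward path: instead of registering a dedicated `exponential_grad` op, the grad op maker above now emits a `fill_any_like` op with `value` set to 0, and the `exponential__grad` entry added to backward.yaml later in this diff invokes `zeros_like`. Both express the same rule: `exponential_` overwrites its input with fresh samples, so the gradient flowing back to `x` is defined as zero. A minimal sketch of that rule in plain NumPy (illustrative only, not Paddle code; the function name is hypothetical):

```python
import numpy as np

def exponential__grad(out_grad):
    # Gradient rule used by this commit: the in-place exponential_ op
    # discards x's values, so dL/dx is all zeros regardless of out_grad.
    return np.zeros_like(out_grad)

# e.g. an incoming [10, 10] gradient produces a [10, 10] zero gradient for x
assert not exponential__grad(np.ones((10, 10))).any()
```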
@@ -46,7 +46,8 @@ PD_REGISTER_INFER_META_FN({api.kernel['func'][0]}, phi::{api.infer_meta['func']}
         'const paddle::optional<Tensor>&': 'const MetaTensor&'
     }
 
-    wrapped_infermeta_name = get_wrapped_infermeta_name(api.api)
+    wrapped_infermeta_name = get_wrapped_infermeta_name(
+        api.kernel['func'][0])
     args = []
     for input_name in api.inputs['names']:
         if input_name in kernel_params:
......
@@ -689,6 +689,17 @@
     func : expm1
   backward : expm1_grad
 
+- api : exponential_
+  args : (Tensor x, float lambda)
+  output : Tensor(out)
+  infer_meta :
+    func : UnchangedInferMeta
+    param : [x]
+  kernel :
+    func : exponential
+  inplace : (x -> out)
+  backward : exponential__grad
+
 - api : eye
   args : (int64_t num_rows, int64_t num_columns, DataType dtype=DataType::FLOAT32, Place place={})
   output : Tensor(out)
......
@@ -720,6 +720,15 @@
     func : expm1_grad
   inplace : (out_grad -> x_grad)
 
+- backward_api : exponential__grad
+  forward : exponential_ (Tensor x, float lambda) -> Tensor(out)
+  args : (Tensor out_grad)
+  output : Tensor(x_grad)
+  infer_meta :
+    func : UnchangedInferMeta
+  invoke : zeros_like(out_grad, DataType::UNDEFINED, {})
+  inplace : (out_grad -> x_grad)
+
 - backward_api : flatten_grad
   forward : flatten(Tensor x, int start_axis, int stop_axis) -> Tensor(out), Tensor(xshape)
   args : (Tensor xshape, Tensor out_grad)
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/exponential_kernel.h"
#include <random>
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/generator.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/distribution_helper.h"
namespace phi {
template <typename T, typename Context>
void ExponentialKernel(const Context& dev_ctx,
const DenseTensor& x,
float lambda,
DenseTensor* out) {
T* out_data = dev_ctx.template Alloc<T>(out);
auto engine = dev_ctx.GetGenerator()->GetCPUEngine();
std::uniform_real_distribution<T> uniform(0.0, 1.0);
phi::funcs::exponential_transform<T> trans(lambda);
for (int64_t i = 0; i < out->numel(); ++i) {
out_data[i] = trans(uniform(*engine));
}
}
} // namespace phi
PD_REGISTER_KERNEL(
exponential, CPU, ALL_LAYOUT, phi::ExponentialKernel, float, double) {}
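The new CPU kernel above allocates the output, pulls the device's CPU RNG engine, draws Uniform(0, 1) samples, and maps each one through `phi::funcs::exponential_transform`. Below is a NumPy sketch of the inverse-CDF sampling scheme this presumably implements; the assumption that the transform computes `-log(1 - u) / lambda` is mine, and Paddle's exact numerics may differ:

```python
import numpy as np

def sample_exponential(shape, lam, seed=100):
    # Sketch of the sampling scheme used by the CPU kernel above:
    # draw u ~ Uniform(0, 1), then apply the inverse CDF of Exponential(lam).
    # Assumption: phi::funcs::exponential_transform computes -log(1 - u) / lam.
    rng = np.random.default_rng(seed)
    u = rng.uniform(0.0, 1.0, size=shape)
    return -np.log1p(-u) / lam

samples = sample_exponential((10, 10), lam=0.5)
assert samples.min() >= 0.0  # exponential samples are non-negative
```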
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void ExponentialKernel(const Context &dev_ctx,
const DenseTensor &x,
float lambda,
DenseTensor *out);
} // namespace phi
-// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -12,31 +12,25 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#pragma once
-
-#include "paddle/fluid/framework/generator.h"
-#include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/framework/operator.h"
+#include "paddle/phi/kernels/exponential_kernel.h"
+
+#include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/funcs/distribution_helper.h"
-#include "paddle/phi/kernels/funcs/math_function.h"
 
-namespace paddle {
-namespace operators {
+namespace phi {
 
-template <typename DeviceContext, typename T>
-class ExponentialKernel;
-
-template <typename DeviceContext, typename T>
-class ExponentialGradKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext& ctx) const override {
-    auto* dx = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
-    dx->mutable_data<T>(ctx.GetPlace());
-
-    phi::funcs::SetConstant<DeviceContext, T> functor;
-    auto& dev_ctx = ctx.template device_context<DeviceContext>();
-    functor(dev_ctx, dx, static_cast<T>(0));
-  }
-};
+template <typename T, typename Context>
+void ExponentialKernel(const Context &dev_ctx,
+                       const DenseTensor &x,
+                       float lambda,
+                       DenseTensor *out) {
+  phi::funcs::uniform_distribution<T> dist;
+  phi::funcs::exponential_transform<T> trans(lambda);
+  phi::funcs::distribution_and_transform<T>(dev_ctx, out, dist, trans);
+}
 
-}  // namespace operators
-}  // namespace paddle
+}  // namespace phi
+
+PD_REGISTER_KERNEL(
+    exponential, GPU, ALL_LAYOUT, phi::ExponentialKernel, float, double) {}
-/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -12,37 +12,15 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#include "paddle/fluid/operators/exponential_op.h"
+#include "paddle/phi/core/compat/op_utils.h"
 
-namespace paddle {
-namespace operators {
+namespace phi {
 
-template <typename T>
-class ExponentialKernel<platform::CUDADeviceContext, T>
-    : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext& ctx) const override {
-    framework::Tensor* out = ctx.Output<framework::Tensor>("Out");
-    auto& dev_cxt = ctx.template device_context<platform::CUDADeviceContext>();
-    T lambda = static_cast<T>(ctx.Attr<float>("lambda"));
-
-    phi::funcs::uniform_distribution<T> dist;
-    phi::funcs::exponential_transform<T> trans(lambda);
-    phi::funcs::distribution_and_transform<T>(dev_cxt, out, dist, trans);
-  }
-};
-
-}  // namespace operators
-}  // namespace paddle
-
-namespace ops = paddle::operators;
-namespace plat = paddle::platform;
-
-REGISTER_OP_CUDA_KERNEL(
-    exponential,
-    ops::ExponentialKernel<plat::CUDADeviceContext, float>,
-    ops::ExponentialKernel<plat::CUDADeviceContext, double>);
-REGISTER_OP_CUDA_KERNEL(
-    exponential_grad,
-    ops::ExponentialGradKernel<plat::CUDADeviceContext, float>,
-    ops::ExponentialGradKernel<plat::CUDADeviceContext, double>);
+KernelSignature ExponentialOpArgumentMapping(
+    const ArgumentMappingContext& ctx) {
+  return KernelSignature("exponential", {"X"}, {"lambda"}, {"Out"});
+}
+
+}  // namespace phi
+
+PD_REGISTER_ARG_MAPPING_FN(exponential, phi::ExponentialOpArgumentMapping);
@@ -18,13 +18,13 @@ import numpy as np
 from op_test import OpTest
 import os
 
-paddle.enable_static()
 paddle.seed(100)
 
 
 class TestExponentialOp1(OpTest):
 
     def setUp(self):
+        paddle.enable_static()
         self.op_type = "exponential"
         self.config()
 
@@ -87,8 +87,14 @@ class TestExponentialAPI(unittest.TestCase):
     def test_dygraph(self):
         paddle.disable_static()
         x = paddle.full([10, 10], -1., dtype='float32')
-        x.exponential_(0.5)
-        self.assertTrue(np.min(x.numpy()) >= 0)
+        x.stop_gradient = False
+        y = 2 * x
+        y.exponential_(0.5)
+        print(y)
+        self.assertTrue(np.min(y.numpy()) >= 0)
+
+        y.backward()
+        self.assertTrue(np.array_equal(x.grad.numpy(), np.zeros([10, 10])))
         paddle.enable_static()
 
     def test_fixed_random_number(self):
......
@@ -1052,7 +1052,9 @@ def exponential_(x, lam=1.0, name=None):
             #  [0.72520673, 0.45208144, 0.30234432]]
     """
-    if paddle.in_dynamic_mode():
+    if in_dygraph_mode():
+        return _C_ops.final_state_exponential_(x, lam)
+    elif paddle.in_dynamic_mode():
         return _C_ops.exponential_(x, "lambda", lam)
 
     check_variable_and_dtype(x, "x", ["float32", "float64"], "exponential")
......
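For reference, a small usage sketch of the Python entry point this dispatch feeds into, modeled directly on the dygraph test added in this commit (printed values depend on the seed):

```python
import numpy as np
import paddle

paddle.seed(100)
x = paddle.full([10, 10], -1.0, dtype='float32')
x.stop_gradient = False

y = 2 * x
y.exponential_(0.5)   # in-place: y now holds Exponential(lam=0.5) samples
assert float(y.numpy().min()) >= 0

y.backward()          # gradient through the overwritten tensor is zero
assert np.array_equal(x.grad.numpy(), np.zeros([10, 10]))
```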