Unverified · Commit 6984fbca authored by Allen Guo, committed by GitHub

[IPU] support dy2static for IPU merge code (#43770)

* feat(): dynamic_to_static support for ipu.

* fix(): format fix.

* fix format

* fix cpplint error

* use phi::errors

* fix format

* fix format

* fix(): add api to restore patched function.

* fix(): identity_loss uses cpu place as expected kernel type.

* doc(): add IPU dy2static related docs.

* fix(): combine test cases.

* fix format

* fix comment

* fix format

* apply comment

* fix compiling

* fix(): align docs.

* fix(): fix identity_loss function docs.

* fix(): adjust mean and sum in identity_loss.

* fix(): minor docs.

* move API to paddle.incubate.identity_loss

* fix UT
Co-authored-by: zhaorui chen <zhaoruic@graphcore.ai>
Parent 2afa9b76
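A minimal usage sketch of the dynamic-to-static flow this change enables, condensed from the unit test added in this commit (it assumes a Paddle build with WITH_IPU=ON and an attached IPU device):

import paddle

class SimpleLayer(paddle.nn.Layer):
    def __init__(self):
        super(SimpleLayer, self).__init__()
        self.conv = paddle.nn.Conv2D(in_channels=3, out_channels=1, kernel_size=2, stride=1)

    def forward(self, x, target=None):
        x = self.conv(x)
        x = paddle.fluid.layers.flatten(x, axis=1)
        if target is not None:
            x = paddle.fluid.layers.softmax(x)
            loss = paddle.fluid.layers.cross_entropy(x, target)
            # identity_loss marks the loss so the IPU backend knows where backpropagation starts
            return x, paddle.incubate.identity_loss(loss, 1)
        return x

paddle.set_device('ipu')
model = paddle.jit.to_static(SimpleLayer())
optim = paddle.optimizer.Adam(learning_rate=0.01, parameters=model.parameters())

# Creating IpuStrategy in dynamic mode registers the dy2static patches added in this commit.
ipu_strategy = paddle.static.IpuStrategy()
ipu_strategy.set_graph_config(num_ipus=1, is_training=True, micro_batch_size=1)
ipu_strategy.set_optimizer(optim)

data = paddle.uniform((32, 3, 10, 10), dtype='float32')
label = paddle.randint(0, 10, shape=[32], dtype='int64')
for epoch in range(10):
    # On IPU a single model() call performs forward, backward and the weight update.
    _, loss = model(data, label)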
......@@ -46,6 +46,10 @@ static ExecutionStrategy GetExecutionStrategy(const platform::Place &place) {
execution_strategy.num_threads_ = 1;
break;
}
case platform::DeviceType::IPU: {
execution_strategy.num_threads_ = 1;
break;
}
default:
PADDLE_THROW(platform::errors::Unavailable("Unsupported Device type %d.",
device_type));
......
......@@ -37,6 +37,9 @@ void IpuRuntimeReplacerPass::ApplyImpl(ir::Graph* graph) const {
ipu_rt_op_desc.SetInput("FeedList", feed_list);
ipu_rt_op_desc.SetOutput("FetchList", fetch_list);
ipu_rt_op_desc.Flush();
// set op_role to avoid program.clone failure
ipu_rt_op_desc.SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(),
{static_cast<int>(framework::OpRole::kForward)});
// Create a new node for the ipu_runtime_op.
auto* ipu_rt_node = graph->CreateOpNode(&ipu_rt_op_desc);
......
......@@ -287,6 +287,19 @@ void IpuOptimizerExtractPass::ApplyImpl(ir::Graph* graph) const {
} else if (op_role == OpRole::kLRSched) {
// op_role == OpRole::kLRSched | OpRole::kOptimize
new_op.SetAttr("with_lr_sched", true);
} else if (op_type == "identity_loss") {
auto outputs = op->Outputs();
PADDLE_ENFORCE_EQ(
outputs.size(),
1,
platform::errors::InvalidArgument("Can only support one loss key"));
auto losses = outputs.begin()->second;
PADDLE_ENFORCE_EQ(
losses.size(),
1,
platform::errors::InvalidArgument("Can only support one loss name"));
auto loss_var = losses.front();
new_op.SetAttr("loss_var", loss_var);
}
}
......
......@@ -548,6 +548,15 @@ ir::Graph *ParallelExecutorPrivate::ApplyMemoryOptimizePass(ir::Graph *graph) {
PADDLE_THROW(platform::errors::PermissionDenied(
"Paddle can't use XPU device since it's not compiled with XPU,"
"Please recompile or reinstall Paddle with XPU support."));
#endif
} else if (platform::is_ipu_place(place)) {
#if defined(PADDLE_WITH_IPU)
gc.reset(new IPUGarbageCollector(place, max_memory_size));
VLOG(10) << "Created " << i << "-th GarbageCollector at " << place;
#else
PADDLE_THROW(platform::errors::PermissionDenied(
"Paddle can't use IPU device since it's not compiled with IPU,"
"Please recompile or reinstall Paddle with IPU support."));
#endif
} else if (platform::is_custom_place(place)) {
#if defined(PADDLE_WITH_CUSTOM_DEVICE)
......
......@@ -394,6 +394,16 @@ PreparedOp PrepareImpl(
kernel_iter = kernels.find(expected_kernel_key);
}
#endif
#ifdef PADDLE_WITH_IPU
if (kernel_iter == kernels.end() &&
paddle::platform::is_ipu_place(expected_kernel_key.place_)) {
VLOG(3) << "missing IPU kernel: " << op.Type()
<< ", expected_kernel_key:" << expected_kernel_key
<< ", falling back to CPU one!";
expected_kernel_key.place_ = platform::CPUPlace();
kernel_iter = kernels.find(expected_kernel_key);
}
#endif
#ifdef PADDLE_WITH_MLU
if (kernel_iter == kernels.end() &&
paddle::platform::is_mlu_place(expected_kernel_key.place_)) {
......
......@@ -140,6 +140,15 @@ paddle::framework::GarbageCollector* Tracer::MutableGarbageCollectorIfNotExists(
PADDLE_THROW(platform::errors::PermissionDenied(
"Paddle can't use NPU device since it's not compiled with NPU,"
"Please recompile or reinstall Paddle with NPU support."));
#endif
} else if (platform::is_ipu_place(place)) {
#if defined(PADDLE_WITH_IPU)
gc.reset(new framework::IPUGarbageCollector(place, 0));
VLOG(10) << "Created GarbageCollector at " << place;
#else
PADDLE_THROW(platform::errors::PermissionDenied(
"Paddle can't use IPU device since it's not compiled with IPU,"
"Please recompile or reinstall Paddle with IPU support."));
#endif
} else if (platform::is_mlu_place(place)) {
#if defined(PADDLE_WITH_MLU)
......
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <string>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/unary.h"
namespace paddle {
namespace operators {
class IdentityLossOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"),
platform::CPUPlace());
}
};
class IdentityLossOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "(Tensor) The input of identity_loss op");
AddOutput("Out", "(Tensor) The output of identity_loss op");
AddAttr<int>("reduction", "(int, default 1) The type of reduction applied to the loss: 0 = sum, 1 = mean, 2 = none.")
.SetDefault(1)
.InEnum({0, 1, 2});
AddComment(R"DOC(
The IdentityLoss operator marks the loss variable.
)DOC");
}
};
class IdentityLossGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
ctx->ShareLoD("X", framework::GradVarName("X"));
}
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
auto input_data_type = OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out"));
return framework::OpKernelType(input_data_type, platform::CPUPlace());
}
};
template <typename T>
class IdentityLossGradMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
void Apply(GradOpPtr<T> grad_op) const override {
grad_op->SetType("identity_loss_grad");
grad_op->SetInput("X", this->Input("X"));
grad_op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
grad_op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
grad_op->SetAttrMap(this->Attrs());
}
};
DECLARE_INPLACE_OP_INFERER(IdentityLossInplaceInferer, {"X", "Out"});
DECLARE_INPLACE_OP_INFERER(IdentityLossGradInplaceInferer,
{framework::GradVarName("Out"),
framework::GradVarName("X")});
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(identity_loss,
IdentityLossInferShapeFunctor,
PD_INFER_META(phi::IdentityLossInferMeta));
REGISTER_OPERATOR(identity_loss,
ops::IdentityLossOp,
ops::IdentityLossOpMaker,
ops::IdentityLossGradMaker<paddle::framework::OpDesc>,
ops::IdentityLossGradMaker<paddle::imperative::OpBase>,
ops::IdentityLossInplaceInferer,
IdentityLossInferShapeFunctor);
REGISTER_OPERATOR(identity_loss_grad,
ops::IdentityLossGradOp,
ops::IdentityLossGradInplaceInferer);
......@@ -535,7 +535,9 @@ void Compiler::LowerOptimizer(const Scope* scope) {
resources_->loss_var = resources_->tensors[loss_var];
resources_->with_lr_sched =
BOOST_GET_CONST(bool, op_desc->GetAttr("with_lr_sched"));
if (op_desc->HasAttr("lr_var")) {
if (ipu_strategy_->is_dynamic) {
resources_->lr = ipu_strategy_->lr;
} else if (op_desc->HasAttr("lr_var")) {
auto lr_var = BOOST_GET_CONST(std::string, op_desc->GetAttr("lr_var"));
resources_->lr_var = lr_var;
resources_->lr = GetSingleVarFromScope<float>(scope, lr_var);
......
......@@ -213,8 +213,13 @@ void Executor::Run(const std::vector<const Tensor *> &inputs,
optimizer = compiler_resources_->eval_optimizer.get();
} else {
VLOG(10) << "Update learning_rate";
auto new_lr =
GetSingleVarFromScope<float>(scope_, compiler_resources_->lr_var);
float new_lr;
if (ipu_strategy_->is_dynamic) {
new_lr = ipu_strategy_->lr;
} else {
new_lr =
GetSingleVarFromScope<float>(scope_, compiler_resources_->lr_var);
}
VLOG(10) << "New Lr: " << new_lr;
optimizer = compiler_resources_->UpdateOptimizer(new_lr);
}
......
......@@ -101,6 +101,10 @@ IpuStrategy::IpuStrategy() {
ADD_STRING_OPTION(onnx_dump_path);
ADD_STRING_OPTION(weight_decay_mode);
// dy2static support
ADD_DOUBLE_OPTION(lr);
ADD_BOOL_OPTION(is_dynamic);
#undef ADD_STRING_OPTION
#undef ADD_DOUBLE_OPTION
#undef ADD_UINT64_OPTION
......
......@@ -112,6 +112,12 @@ class IpuStrategy {
// Custom ops
std::vector<IpuCustomOpIdentifier> custom_ops;
// lr for dynamic2static
float lr = 0.0;
// whether in dynamic mode
bool is_dynamic = false;
public:
void AddBoolOption(const std::string &option, bool value);
void AddUint64Option(const std::string &option, std::uint64_t value);
......
......@@ -85,6 +85,17 @@ Node *identity_handler(Graph *graph, Node *node) {
graph, node, "popart_identity", node->inputs, node->outputs);
}
Node *identity_loss_handler(Graph *graph, Node *node) {
auto *op = node->Op();
auto reduction = BOOST_GET_CONST(int, op->GetAttr("reduction"));
return CreateBaseOp(graph,
node,
"popart_identity_loss",
node->inputs,
node->outputs,
{{"reduction", reduction}});
}
Node *detach_handler(Graph *graph, Node *node) {
return CreateBaseOp(
graph, node, "popart_detach_v2", node->inputs, node->outputs);
......@@ -101,4 +112,5 @@ REGISTER_HANDLER(popart_optimizer, popart_optimizer_handler);
REGISTER_HANDLER(checkpointoutput, checkpointoutput_handler);
REGISTER_HANDLER(custom_nll_loss, custom_nll_loss_handler);
REGISTER_HANDLER(identity, identity_handler);
REGISTER_HANDLER(identity_loss, identity_loss_handler);
REGISTER_HANDLER(detach, detach_handler);
......@@ -17,5 +17,6 @@
#pragma once
OP_DECL(popart_nllloss_v2, aiGraphcoreOpset.nllloss, SIG_ARG(INT32,popart::ReductionType,reduction) OPT_ARG(INT32,ignoreIndex) ARG(BOOL,inputIsLogProbability) ) // NOLINT
OP_DECL(popart_identity_loss, aiGraphcoreOpset.identityloss, SIG_ARG(INT32,popart::ReductionType,reduction) ) // NOLINT
// clang-format on
......@@ -123,6 +123,8 @@ DeviceType Place2DeviceType(const platform::Place& place) {
return platform::DeviceType::CUDA;
} else if (platform::is_xpu_place(place)) {
return platform::DeviceType::XPU;
} else if (platform::is_ipu_place(place)) {
return platform::DeviceType::IPU;
} else if (platform::is_mlu_place(place)) {
return platform::DeviceType::MLU;
} else {
......
......@@ -142,6 +142,8 @@ static const platform::Place PyObjectToPlace(const py::object &place_obj) {
return place_obj.cast<platform::CUDAPinnedPlace>();
} else if (py::isinstance<platform::NPUPlace>(place_obj)) {
return place_obj.cast<platform::NPUPlace>();
} else if (py::isinstance<platform::IPUPlace>(place_obj)) {
return place_obj.cast<platform::IPUPlace>();
} else if (py::isinstance<platform::Place>(place_obj)) {
return place_obj.cast<platform::Place>();
} else if (py::isinstance<platform::MLUPlace>(place_obj)) {
......@@ -151,8 +153,8 @@ static const platform::Place PyObjectToPlace(const py::object &place_obj) {
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"Place should be one of "
"Place/CPUPlace/XPUPlace/CUDAPlace/CUDAPinnedPlace/NPUPlace/MLUPlace/"
"CustomPlace"));
"Place/CPUPlace/XPUPlace/CUDAPlace/CUDAPinnedPlace/NPUPlace/IPUPlace/"
"MLUPlace/CustomPlace"));
}
}
......@@ -198,6 +200,8 @@ static void InitVarBaseAndTensor(imperative::VarBase *self,
tensor, array, place, zero_copy);
} else if (platform::is_npu_place(place)) {
SetTensorFromPyArray<platform::NPUPlace>(tensor, array, place, zero_copy);
} else if (platform::is_ipu_place(place)) {
SetTensorFromPyArray<platform::IPUPlace>(tensor, array, place, zero_copy);
} else if (platform::is_mlu_place(place)) {
SetTensorFromPyArray<platform::MLUPlace>(tensor, array, place, zero_copy);
} else if (platform::is_custom_place(place)) {
......@@ -206,7 +210,8 @@ static void InitVarBaseAndTensor(imperative::VarBase *self,
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"Place should be one of "
"CPUPlace/XPUPlace/CUDAPlace/CUDAPinnedPlace/NPUPlace/MLUPlace"));
"CPUPlace/XPUPlace/CUDAPlace/CUDAPinnedPlace/NPUPlace/IPUPlace/"
"MLUPlace"));
}
self->SetDataType(framework::TransToProtoVarType(tensor->dtype()));
}
......@@ -1856,6 +1861,18 @@ void BindImperative(py::module *m_ptr) {
return new_var;
},
py::return_value_policy::copy)
.def(
"_copy_to",
[](const std::shared_ptr<imperative::VarBase> &self,
const platform::IPUPlace &place,
bool blocking) {
auto new_var = self->NewVarBase(place, blocking);
if (!blocking) {
IncreaseVarbaseReferenceCountUntilCopyComplete(self, place);
}
return new_var;
},
py::return_value_policy::copy)
.def(
"_copy_to",
[](const std::shared_ptr<imperative::VarBase> &self,
......@@ -2140,6 +2157,11 @@ void BindImperative(py::module *m_ptr) {
self.SetExpectedPlace(*p);
VLOG(4) << "Tracer(" << &self << ")"
<< " set expected place " << *p;
} else if (py::isinstance<platform::IPUPlace>(obj)) {
auto p = obj.cast<platform::IPUPlace *>();
self.SetExpectedPlace(*p);
VLOG(4) << "Tracer(" << &self << ")"
<< " set expected place " << *p;
} else if (py::isinstance<platform::MLUPlace>(obj)) {
auto p = obj.cast<platform::MLUPlace *>();
self.SetExpectedPlace(*p);
......@@ -2158,7 +2180,7 @@ void BindImperative(py::module *m_ptr) {
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"Incompatible Place Type: supports XPUPlace, CUDAPlace, "
"CPUPlace, NPUPlace, MLUPlace"
"CPUPlace, NPUPlace, IPUPlace, MLUPlace"
"and CUDAPinnedPlace, "
"but got Unknown Type!"));
}
......@@ -2313,6 +2335,28 @@ void BindImperative(py::module *m_ptr) {
inplace_map);
}
})
.def("trace",
[](imperative::Tracer &self,
const std::string &type,
const PyNameVarBaseMap &ins,
const PyNameVarBaseMap &outs,
framework::AttributeMap attrs,
const platform::IPUPlace &place,
bool trace_backward,
const std::map<std::string, std::string> &inplace_map = {}) {
auto ins_map = ConvertToNameVarBaseMap(ins);
auto outs_map = ConvertToNameVarBaseMap(outs);
{
py::gil_scoped_release release;
self.TraceOp<imperative::VarBase>(type,
std::move(ins_map),
std::move(outs_map),
std::move(attrs),
place,
trace_backward,
inplace_map);
}
})
.def("trace",
[](imperative::Tracer &self,
const std::string &type,
......
......@@ -3262,6 +3262,18 @@ void ChannelShuffleInferMeta(const MetaTensor& x,
out->set_dims(output_dims);
}
void IdentityLossInferMeta(const MetaTensor& x,
int reduction,
MetaTensor* out) {
if (reduction == 2) {
out->set_dtype(x.dtype());
out->set_dims(x.dims());
} else {
out->set_dims(phi::make_ddim({1}));
out->set_dtype(x.dtype());
}
}
} // namespace phi
PD_REGISTER_INFER_META_FN(flatten, phi::FlattenInferMeta);
......
......@@ -469,4 +469,6 @@ void ChannelShuffleInferMeta(const MetaTensor& x,
const std::string& data_format,
MetaTensor* out);
void IdentityLossInferMeta(const MetaTensor& x, int reduction, MetaTensor* out);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/identity_loss_grad_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/kernels/mean_all_grad_kernel.h"
#include "paddle/phi/kernels/reduce_sum_grad_kernel.h"
namespace phi {
template <typename T, typename Context>
void IdentityLossGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& out_grad,
const int reduction,
DenseTensor* x_grad) {
switch (reduction) {
case 0:
// sum
phi::ReduceSumGradKernel<T>(
dev_ctx, x, out_grad, std::vector<int64_t>{0}, false, true, x_grad);
break;
case 1:
// mean
phi::MeanAllGradKernel<T>(dev_ctx, x, out_grad, x_grad);
break;
case 2:
// none
phi::Copy(dev_ctx, out_grad, dev_ctx.GetPlace(), false, x_grad);
break;
default:
// error
PADDLE_THROW(phi::errors::InvalidArgument(
"reduction should be 0, 1 or 2, but got %d.", reduction));
}
}
} // namespace phi
PD_REGISTER_KERNEL(identity_loss_grad,
CPU,
ALL_LAYOUT,
phi::IdentityLossGradKernel,
float,
double) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/identity_loss_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/kernels/mean_all_kernel.h"
#include "paddle/phi/kernels/reduce_sum_kernel.h"
namespace phi {
template <typename T, typename Context>
void IdentityLossKernel(const Context& dev_ctx,
const DenseTensor& x,
const int reduction,
DenseTensor* out) {
switch (reduction) {
case 0:
// sum
phi::SumRawKernel<T>(
dev_ctx, x, std::vector<int64_t>{0}, false, true, out->dtype(), out);
break;
case 1:
// mean
phi::MeanAllKernel<T>(dev_ctx, x, out);
break;
case 2:
// none
phi::Copy(dev_ctx, x, dev_ctx.GetPlace(), false, out);
break;
default:
// error
PADDLE_THROW(phi::errors::InvalidArgument(
"reduction should be 0, 1 or 2, but got %d.", reduction));
}
}
} // namespace phi
PD_REGISTER_KERNEL(
identity_loss, CPU, ALL_LAYOUT, phi::IdentityLossKernel, float, double) {}
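A small, hedged check of the reduction semantics implemented by the kernel above; this is only a sketch assuming a CPU build of Paddle, run in dygraph mode so the CPU kernel registered here is dispatched:

import numpy as np
import paddle

x_np = np.random.uniform(-2, 2, [3, 5]).astype('float32')
x = paddle.to_tensor(x_np)

out_sum = paddle.incubate.identity_loss(x, reduction=0)   # reduce_sum path, output shape [1]
out_mean = paddle.incubate.identity_loss(x, reduction=1)  # mean_all path, output shape [1]
out_none = paddle.incubate.identity_loss(x, reduction=2)  # plain copy, shape unchanged

assert np.allclose(out_sum.numpy(), x_np.sum())
assert np.allclose(out_mean.numpy(), x_np.mean())
assert np.allclose(out_none.numpy(), x_np)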
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/common/scalar.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/device_context.h"
namespace phi {
template <typename T, typename Context>
void IdentityLossGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& out_grad,
const int reduction,
DenseTensor* x_grad);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/common/scalar.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/device_context.h"
namespace phi {
template <typename T, typename Context>
void IdentityLossKernel(const Context& dev_ctx,
const DenseTensor& x,
const int reduction,
DenseTensor* out);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature IdentityLossOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("identity_loss", {"X"}, {"reduction"}, {"Out"});
}
KernelSignature IdentityLossGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature(
"identity_loss_grad", {"X", "Out@GRAD"}, {"reduction"}, {"X@GRAD"});
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(identity_loss, phi::IdentityLossOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(identity_loss_grad,
phi::IdentityLossGradOpArgumentMapping);
......@@ -506,6 +506,192 @@ class CompiledProgram(object):
return place_list
class IpuDynamicPatcher(object):
"""
Patcher for IPU dynamic2static support.
"""
patcher_cache = []
def __init__(self):
pass
@staticmethod
def convert_concrete_program(ipu_strategy,
concrete_program,
class_instance=None):
"""
Convert the ConcreteProgram to IPUConcreteProgram.
"""
from ..fluid.dygraph.base import switch_to_static_graph
from ..fluid import backward
from ..fluid.initializer import Constant
from ..fluid.framework import device_guard
import paddle
inputs = concrete_program.inputs
outputs = concrete_program.outputs
startup_program = concrete_program.startup_program
scope = paddle.static.global_scope()
@switch_to_static_graph
def append_backward_desc():
program = concrete_program.main_program
# backward with optimizer to add backward graph to program
backward.gradients_with_optimizer(program, ipu_strategy._optimizer)
# initialize backward parameters
exe = paddle.static.Executor(paddle.CPUPlace())
startup_program = paddle.static.default_startup_program()
exe.run(startup_program)
return program
if ipu_strategy.enable_fp16:
class_instance.to(dtype="float16")
# copy the bias and filters
for param_or_buffer in concrete_program.parameters:
param_or_buffer_tensor = scope.var(
param_or_buffer.name).get_tensor()
src_tensor = param_or_buffer.value().get_tensor()
param_or_buffer_tensor._share_data_with(src_tensor)
# TODO(czr): feed and fetch lists need to consider more types
if class_instance:
feed_list = [elem.name for elem in inputs[1:] if elem is not None]
else:
feed_list = [elem.name for elem in inputs if elem is not None]
fetch_list = [elem.name for elem in outputs]
if ipu_strategy.is_training:
concrete_program.main_program = append_backward_desc()
# copy optimizer parameters
optimizer = ipu_strategy._optimizer
for k, v in optimizer._accumulators.items():
for param_name, var_tmp in v.items():
var = optimizer.helper.create_global_variable(
name=var_tmp.name,
persistable=True,
dtype=var_tmp.dtype,
type=var_tmp.type,
shape=var_tmp.shape,
belong_to_optimizer=True)
device = optimizer._get_device_for_param(param_name)
with device_guard(device):
optimizer.helper.set_variable_initializer(
var, initializer=Constant(value=0.0))
param_or_lr_tensor = scope.find_var(
var_tmp.name).get_tensor()
optim_tensor = var.value().get_tensor()
param_or_lr_tensor._share_data_with(optim_tensor)
optimizer._accumulators[k][param_name] = var
@switch_to_static_graph
def func_compile():
if ipu_strategy.enable_fp16:
amp_list = paddle.static.amp.CustomOpLists()
amp_list.unsupported_list = {"cumsum"}
to_fp16_var_names = paddle.static.amp.cast_model_to_fp16(
concrete_program.main_program,
amp_list,
use_fp16_guard=False)
paddle.static.amp.cast_parameters_to_fp16(
paddle.CPUPlace(),
concrete_program.main_program,
to_fp16_var_names=to_fp16_var_names)
program = IpuCompiledProgram(concrete_program.main_program,
ipu_strategy=ipu_strategy,
scope=scope).compile(
feed_list, fetch_list)
return program
main_program = func_compile()
concrete_program.main_program = main_program
return concrete_program
@staticmethod
def patch_program_cache(ipu_strategy):
""" Monkey patch ProgramCache.__getitem__ to support dynamic-to-static on IPU.
Args:
ipu_strategy: The ipu_strategy used in dynamic graph.
Returns:
None
"""
from ..fluid.dygraph.dygraph_to_static.program_translator import ProgramCache
from ..fluid.dygraph.dygraph_to_static.program_translator import CacheKey
from ..fluid.dygraph.dygraph_to_static import logging_utils
from ..fluid.dygraph.dygraph_to_static.program_translator import MAX_TRACED_PROGRAM_COUNT
from ..fluid.dygraph.dygraph_to_static.partial_program import partial_program_from
old_getter = ProgramCache.__getitem__
def patch_getter(self, item):
if not isinstance(item, CacheKey):
raise ValueError(
'type(item) should be CacheKey, but received %s' %
type_name(item))
item_id = hash(item)
self._recent_key = item_id
if item_id not in self._caches or ipu_strategy.need_compile:
if item_id in self._caches:
logging_utils.warn(
"ipu_strategy changes detected. Please sync weights.")
if self._caches and not ipu_strategy.need_compile:
logging_utils.warn(
"dynamic2static on IPU doesn't support multiple caches. Please make sure "
"dynamic inputs are not used.")
concrete_program, _ = self._build_once(item)
concrete_program = IpuDynamicPatcher.convert_concrete_program(
ipu_strategy, concrete_program, item.class_instance)
self._caches[item_id] = (concrete_program,
partial_program_from(concrete_program))
# Note: raise a warning if the number of traced programs exceeds `max_tracing_count`
current_tracing_count = len(self._caches)
if current_tracing_count > MAX_TRACED_PROGRAM_COUNT:
logging_utils.warn(
"Current traced program number: {} > `max_tracing_count`:{}. Too many cached programs will bring expensive overhead. "
"The reason may be: (1) passing tensors with different shapes, (2) passing python objects instead of tensors."
.format(current_tracing_count,
MAX_TRACED_PROGRAM_COUNT))
return self._caches[item_id]
setattr(ProgramCache, '__getitem__', patch_getter)
IpuDynamicPatcher.patcher_cache.append(
[ProgramCache, '__getitem__', old_getter])
@staticmethod
def patch_lr_scheduler(ipu_strategy):
from paddle.optimizer.lr import LRScheduler
# For IPU dynamic graph usage, lr_var is not synced in the executor as it is in static mode.
# Manually set lr to ipu_strategy to update the lr.
old_step = LRScheduler.step
def patch_step(self, epoch=None):
old_step(self, epoch)
ipu_strategy.set_options({"lr": self.last_lr})
setattr(LRScheduler, 'step', patch_step)
IpuDynamicPatcher.patcher_cache.append([LRScheduler, 'step', old_step])
@staticmethod
def register_patch(ipu_strategy):
IpuDynamicPatcher.patch_program_cache(ipu_strategy)
IpuDynamicPatcher.patch_lr_scheduler(ipu_strategy)
@staticmethod
def release_patch():
for module, key, attr in IpuDynamicPatcher.patcher_cache:
setattr(module, key, attr)
class IpuStrategy(object):
"""
Help users precisely control the graph building in :code:`paddle.static.IpuCompiledProgram` .
......@@ -542,10 +728,121 @@ class IpuStrategy(object):
self._ipu_strategy.set_options(default_options)
self.has_custom_ops = False
self.custom_op_names = []
self.need_compile = True
else:
raise RuntimeError(
"Can not use IpuStrategy in non IPU compiled environment, please re-compile with WITH_IPU=ON."
)
from paddle import in_dynamic_mode
if in_dynamic_mode():
self.register_patch()
def register_patch(self):
"""
Register patch functions to support dynamic-to-static on IPU. This operation breaks the dy2static functionality on CPU.
Use `release_patch` to release the patches.
Examples:
.. code-block:: python
# required: ipu
import paddle
import paddle.static as static
ipu_strategy = static.IpuStrategy()
ipu_strategy.register_patch()
"""
IpuDynamicPatcher.register_patch(self)
def release_patch(self):
"""
Release the registered IPU functions.
Examples:
.. code-block:: python
# required: ipu
import paddle
import paddle.static as static
ipu_strategy = static.IpuStrategy()
ipu_strategy.release_patch()
"""
IpuDynamicPatcher.release_patch()
def set_optimizer(self, optimizer):
"""
Set optimizer to ipu_strategy in dynamic mode.
Args:
optimizer (Optimizer): Optimizer to be used in training.
Returns:
None.
Examples:
.. code-block:: python
# required: ipu
import paddle
import paddle.static as static
linear = paddle.nn.Linear(10, 10)
optimizer = paddle.optimizer.SGD(learning_rate=0.01,
parameters=linear.parameters())
ipu_strategy = static.IpuStrategy()
ipu_strategy.set_optimizer(optimizer)
"""
from paddle import in_dynamic_mode
if in_dynamic_mode():
self._optimizer = optimizer
optimizer_attrs = self.parse_optimizer(optimizer)
self._ipu_strategy.set_options(optimizer_attrs)
else:
raise RuntimeError("Only needs to set optimizer in dynamic mode.")
def parse_optimizer(self, optimizer):
"""
Parse optimizer attributes for IPU dynamic-to-static support. Currently only the learning rate is parsed.
Args:
optimizer (Optimizer): Optimizer to be parsed.
Returns:
Dict.
Examples:
.. code-block:: python
# required: ipu
import paddle
import paddle.static as static
linear = paddle.nn.Linear(10, 10)
optimizer = paddle.optimizer.SGD(learning_rate=0.01,
parameters=linear.parameters())
ipu_strategy = static.IpuStrategy()
attrs = ipu_strategy.parse_optimizer(optimizer)
"""
def get_lr():
from paddle.optimizer.lr import LRScheduler
if isinstance(optimizer._learning_rate, float):
return {"lr": optimizer._learning_rate}
elif isinstance(optimizer._learning_rate, LRScheduler):
return {"lr": optimizer._learning_rate()}
attr_fn = [get_lr]
optimizer_attrs = {"is_dynamic": True}
for fn in attr_fn:
optimizer_attrs.update(fn())
return optimizer_attrs
def set_graph_config(self,
num_ipus=1,
......@@ -743,6 +1040,10 @@ class IpuStrategy(object):
ipu_strategy.set_options(options)
"""
self._ipu_strategy.set_options(options)
# check whether to recompile program with updated ipu options.
recompile_white_list = {'lr'}
if options.keys() - recompile_white_list:
self.need_compile = True
def get_option(self, option):
"""
......@@ -1050,4 +1351,6 @@ class IpuCompiledProgram(object):
if not hasattr(program, 'org_program'):
program.org_program = self._program
self._ipu_strategy.need_compile = False
return program
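A hedged illustration of the recompile check added to `set_options`: only the white-listed `lr` option skips recompilation, which lets the patched LRScheduler push new learning rates each step without rebuilding the program. The `need_compile` reset normally done by `IpuCompiledProgram.compile` is simulated here, and an IPU build of Paddle is assumed:

import paddle

ipu_strategy = paddle.static.IpuStrategy()
ipu_strategy.need_compile = False  # simulate the state right after IpuCompiledProgram.compile()

ipu_strategy.set_options({'lr': 0.005})      # 'lr' is the only white-listed key
assert not ipu_strategy.need_compile         # learning-rate updates do not force a rebuild

ipu_strategy.set_options({'onnx_dump_path': 'dump.onnx'})  # any non-white-listed option
assert ipu_strategy.need_compile             # the cached program is recompiled on the next lookup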
......@@ -1230,6 +1230,63 @@ def softmax_with_cross_entropy(logits,
return_softmax, axis)
def identity_loss(x, reduction="none"):
r"""Marks a tensor as being part of the loss calculation for IPU.
This operator is used to mark the (final) loss of a model so that
it is used as the starting point of backpropagation.
When `reduction` is `none`, return raw `Out`.
When `reduction` is `mean`, return
.. math::
Out = MEAN(Out)
When `reduction` is `sum`, return
.. math::
Out = SUM(Out)
Parameters:
x (Variable): The input tensor. The shape is [N, *], where N is the batch size and `*` means any number of
additional dimensions. Its data type should be float32 or float64 on CPU and float16 or float32 on IPU.
reduction (str|int, optional): How to reduce the loss output. Supported string values are 'sum', 'mean' and 'none';
the corresponding int values are 0, 1 and 2 respectively. The default value is "none".
Returns:
Variable: The loss ``Tensor`` with the specified reduction applied.
Examples:
.. code-block:: python
import paddle.fluid as fluid
import paddle
paddle.enable_static()
loss = fluid.data(name="loss", shape=[-1, 1], dtype="float32")
out = paddle.incubate.identity_loss(loss, reduction=1)
"""
if isinstance(reduction, str):
reduction = {"sum": 0, "mean": 1, "none": 2}.get(reduction.lower())
if reduction is None:
raise Exception("Unsupported reduction type.")
if _non_static_mode():
return _C_ops.identity_loss(x, "reduction", reduction)
check_variable_and_dtype(x, 'x', ['float32', 'float64'], "identity_loss")
attrs = {'reduction': reduction}
helper = LayerHelper('identity_loss', **locals())
dtype = helper.input_dtype(input_param_name='x')
out = helper.create_variable_for_type_inference(dtype)
helper.append_op(type="identity_loss",
inputs={"X": x},
outputs={"Out": out},
attrs=attrs)
return out
def rank_loss(label, left, right, name=None):
r"""
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import numpy as np
import unittest
import sys
import os
import paddle
import paddle.fluid as fluid
from paddle.jit import to_static
from paddle.utils.cpp_extension import load
from paddle.optimizer.lr import LRScheduler
import tempfile
SEED = 2022
class SimpleLayer(paddle.nn.Layer):
def __init__(self, use_ipu=False):
super(SimpleLayer, self).__init__()
self.use_ipu = use_ipu
self.conv = paddle.nn.Conv2D(in_channels=3,
out_channels=1,
kernel_size=2,
stride=1)
def forward(self, x, target=None):
x = self.conv(x)
x = paddle.fluid.layers.flatten(x, axis=1)
if target is not None:
x = paddle.fluid.layers.softmax(x)
loss = paddle.fluid.layers.cross_entropy(x, target)
if self.use_ipu:
loss = paddle.incubate.identity_loss(loss, 1)
else:
loss = paddle.mean(loss)
return x, loss
return x
class TestBase(unittest.TestCase):
@classmethod
def setUpClass(cls):
paddle.disable_static()
cls.save_path = tempfile.TemporaryDirectory()
@classmethod
def tearDownClass(cls):
cls.save_path.cleanup()
def _test(self, use_ipu=False):
paddle.seed(SEED)
np.random.seed(SEED)
model = SimpleLayer(use_ipu)
specs = [
paddle.static.InputSpec(name="x",
shape=[32, 3, 10, 10],
dtype="float32"),
paddle.static.InputSpec(name="target", shape=[32], dtype="int64"),
]
model = paddle.jit.to_static(model, input_spec=specs)
optim = paddle.optimizer.Adam(learning_rate=0.01,
parameters=model.parameters())
data = paddle.uniform((32, 3, 10, 10), dtype='float32')
label = paddle.randint(0, 10, shape=[32], dtype='int64')
model_path = '{}/model_state_dict_{}.pdparams'.format(
self.save_path, 'ipu' if use_ipu else 'cpu')
optim_path = '{}/optim_state_dict_{}.pdopt'.format(
self.save_path, 'ipu' if use_ipu else 'cpu')
if use_ipu:
device = paddle.set_device('ipu')
ipu_strategy = paddle.static.IpuStrategy()
ipu_strategy.set_graph_config(num_ipus=1,
is_training=True,
micro_batch_size=1,
enable_manual_shard=False)
ipu_strategy.set_precision_config(enable_fp16=True)
ipu_strategy.set_optimizer(optim)
data = data.astype(np.float16)
result = []
for epoch in range(100):
# IPU only needs to call model() to do forward/backward/grad update
pred, loss = model(data, label)
if not use_ipu:
loss.backward()
optim.step()
optim.clear_grad()
result.append(loss)
if use_ipu:
paddle.fluid.core.IpuBackend.get_instance().weights_to_host()
paddle.save(model.state_dict(), model_path)
paddle.save(optim.state_dict(), optim_path)
model.set_state_dict(paddle.load(model_path))
optim.set_state_dict(paddle.load(optim_path))
for epoch in range(100):
# IPU only needs to call model() to do forward/backward/grad update
pred, loss = model(data, label)
if not use_ipu:
loss.backward()
optim.step()
optim.clear_grad()
result.append(loss)
return np.array(result)
def test_training(self):
cpu_loss = self._test(False).flatten()
ipu_loss = self._test(True).flatten()
self.assertTrue(np.allclose(ipu_loss, cpu_loss, atol=1e-2))
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import numpy as np
import unittest
import sys
import paddle
import paddle.fluid as fluid
from paddle.jit import to_static
from paddle.utils.cpp_extension import load
from paddle.optimizer.lr import LRScheduler
from paddle.fluid.dygraph.dygraph_to_static.program_translator import ProgramCache
import tempfile
SEED = 2022
class SimpleLayer(paddle.nn.Layer):
def __init__(self, use_ipu=False):
super(SimpleLayer, self).__init__()
self.use_ipu = use_ipu
self.conv = paddle.nn.Conv2D(in_channels=3,
out_channels=1,
kernel_size=2,
stride=1)
@to_static()
def forward(self, x, target=None):
x = self.conv(x)
x = paddle.fluid.layers.flatten(x, axis=1)
if target is not None:
x = paddle.fluid.layers.softmax(x)
loss = paddle.fluid.layers.cross_entropy(x, target)
if self.use_ipu:
loss = paddle.incubate.identity_loss(loss, 1)
else:
loss = paddle.mean(loss)
return x, loss
return x
class TestBase(unittest.TestCase):
@classmethod
def setUpClass(cls):
paddle.disable_static()
def _test(self, use_ipu=False):
paddle.seed(SEED)
np.random.seed(SEED)
model = SimpleLayer(use_ipu)
optim = paddle.optimizer.Adam(learning_rate=0.01,
parameters=model.parameters())
data = paddle.uniform((32, 3, 10, 10), dtype='float32')
label = paddle.randint(0, 10, shape=[32], dtype='int64')
if use_ipu:
device = paddle.set_device('ipu')
ipu_strategy = paddle.static.IpuStrategy()
ipu_strategy.set_graph_config(num_ipus=1,
is_training=True,
micro_batch_size=1,
enable_manual_shard=False)
ipu_strategy.set_optimizer(optim)
result = []
for epoch in range(100):
# IPU only needs to call model() to do forward/backward/grad update
pred, loss = model(data, label)
if not use_ipu:
loss.backward()
optim.step()
optim.clear_grad()
result.append(loss)
if use_ipu:
ipu_strategy.release_patch()
return np.array(result)
def test_training(self):
ipu_loss = self._test(True).flatten()
cpu_loss = self._test(False).flatten()
self.assertTrue(np.allclose(ipu_loss, cpu_loss, atol=1e-4))
class TestSaveLoad(TestBase):
@classmethod
def setUpClass(cls):
paddle.disable_static()
cls.save_path = tempfile.TemporaryDirectory()
@classmethod
def tearDownClass(cls):
cls.save_path.cleanup()
def _test(self, use_ipu=False):
paddle.seed(SEED)
np.random.seed(SEED)
model = SimpleLayer(use_ipu)
optim = paddle.optimizer.Adam(learning_rate=0.01,
parameters=model.parameters())
data = paddle.uniform((32, 3, 10, 10), dtype='float32')
label = paddle.randint(0, 10, shape=[32], dtype='int64')
model_path = '{}/model_state_dict_{}.pdparams'.format(
self.save_path, 'ipu' if use_ipu else 'cpu')
optim_path = '{}/optim_state_dict_{}.pdopt'.format(
self.save_path, 'ipu' if use_ipu else 'cpu')
if use_ipu:
device = paddle.set_device('ipu')
ipu_strategy = paddle.static.IpuStrategy()
ipu_strategy.set_graph_config(num_ipus=1,
is_training=True,
micro_batch_size=1,
enable_manual_shard=False)
ipu_strategy.set_optimizer(optim)
result = []
for epoch in range(100):
# IPU only needs to call model() to do forward/backward/grad update
pred, loss = model(data, label)
if not use_ipu:
loss.backward()
optim.step()
optim.clear_grad()
result.append(loss)
if use_ipu:
paddle.fluid.core.IpuBackend.get_instance().weights_to_host()
paddle.save(model.state_dict(), model_path)
paddle.save(optim.state_dict(), optim_path)
model.set_state_dict(paddle.load(model_path))
optim.set_state_dict(paddle.load(optim_path))
for epoch in range(100):
# IPU only needs to call model() to do forward/backward/grad update
pred, loss = model(data, label)
if not use_ipu:
loss.backward()
optim.step()
optim.clear_grad()
result.append(loss)
if use_ipu:
ipu_strategy.release_patch()
return np.array(result)
class TestPatch(unittest.TestCase):
@classmethod
def setUpClass(cls):
paddle.disable_static()
def test(self, use_ipu=False):
old_getter = ProgramCache.__getitem__
old_step = LRScheduler.step
ipu_strategy = paddle.static.IpuStrategy()
ipu_strategy.release_patch()
reset_getter = ProgramCache.__getitem__
reset_step = LRScheduler.step
self.assertTrue(reset_getter is old_getter)
self.assertTrue(reset_step is old_step)
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import paddle
import paddle.fluid as fluid
import paddle.fluid.compiler as compiler
import paddle.optimizer
import paddle.static
from paddle.fluid.tests.unittests.ipu.op_test_ipu import (IPUOpTest,
np_dtype_to_fluid_str)
from paddle.utils.cpp_extension import load
paddle.enable_static()
class TestBase(IPUOpTest):
def setUp(self):
self.set_atol()
self.set_training()
self.set_feed()
self.set_feed_attr()
self.set_op()
def set_op(self):
# set the op under test
self.op = paddle.incubate.identity_loss
def set_feed(self):
self.feed = {
"x": np.random.uniform(low=-2, high=2, size=[3,
5]).astype('float32'),
}
def set_feed_attr(self):
self.feed_shape = [x.shape for x in self.feed.values()]
self.feed_list = list(self.feed.keys())
self.feed_dtype = [
np_dtype_to_fluid_str(x.dtype) for x in self.feed.values()
]
def _test_base(self, reduction):
scope = fluid.core.Scope()
main_prog = paddle.static.Program()
startup_prog = paddle.static.Program()
SEED = 0
main_prog.random_seed = SEED
startup_prog.random_seed = SEED
with fluid.scope_guard(scope):
with paddle.static.program_guard(main_prog, startup_prog):
x = paddle.static.data(name=self.feed_list[0],
shape=self.feed_shape[0],
dtype=self.feed_dtype[0])
out = self.op(x, reduction)
fetch_list = [out.name]
place = paddle.IPUPlace()
exe = paddle.static.Executor(place)
exe.run(startup_prog)
feed_list = self.feed_list
ipu_strategy = paddle.static.IpuStrategy()
ipu_strategy.set_graph_config(num_ipus=1, is_training=False)
ipu_compiler = compiler.IpuCompiledProgram(
main_prog, ipu_strategy=ipu_strategy)
program = ipu_compiler.compile(feed_list, fetch_list)
ipu_res = exe.run(program, self.feed, fetch_list)
if reduction == 0:
# sum
cpu_res = self.feed['x'].sum()
elif reduction == 1:
# mean
cpu_res = self.feed['x'].mean()
else:
# none
cpu_res = self.feed['x']
self.assertTrue(np.allclose(ipu_res[0], cpu_res, atol=self.atol))
def test_base(self):
# TODO: use string instead of int for reduction
for reduction in [0, 1, 2]:
self._test_base(reduction)
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import paddle
import paddle.fluid as fluid
from paddle.fluid import Program, program_guard
from op_test import OpTest
from paddle.fluid.framework import _test_eager_guard
class TestIdentityLossOp(OpTest):
def setUp(self):
self.max_relative_error = 0.006
self.python_api = paddle.incubate.identity_loss
self.inputs = {}
self.initTestCase()
self.dtype = np.float64
self.op_type = "identity_loss"
self.attrs = {}
self.attrs['reduction'] = self.reduction
input = np.random.random(self.shape).astype(self.dtype)
self.inputs['X'] = input
if self.reduction == 0:
output = input.sum()
elif self.reduction == 1:
output = input.mean()
else:
output = input
self.outputs = {'Out': output}
def test_check_output(self):
paddle.enable_static()
self.check_output(check_eager=True)
paddle.disable_static()
def test_check_grad_normal(self):
paddle.enable_static()
self.check_grad(['X'], 'Out', check_eager=True)
paddle.disable_static()
def initTestCase(self):
self.shape = (4, 10, 10)
self.reduction = 0
class TestCase1(TestIdentityLossOp):
def initTestCase(self):
self.shape = (8, 16, 8)
self.reduction = 0
class TestCase2(TestIdentityLossOp):
def initTestCase(self):
self.shape = (8, 16)
self.reduction = 1
class TestCase3(TestIdentityLossOp):
def initTestCase(self):
self.shape = (4, 8, 16)
self.reduction = 2
class TestIdentityLossFloat32(TestIdentityLossOp):
def set_attrs(self):
self.dtype = 'float32'
class TestIdentityLossOpError(unittest.TestCase):
def test_errors(self):
paddle.enable_static()
with program_guard(Program(), Program()):
input_data = np.random.random((2, 4)).astype("float32")
def test_int():
paddle.incubate.identity_loss(x=input_data, reduction=3)
self.assertRaises(Exception, test_int)
def test_string():
paddle.incubate.identity_loss(x=input_data,
reduction="wrongkey")
self.assertRaises(Exception, test_string)
def test_dtype():
x2 = fluid.layers.data(name='x2', shape=[1], dtype='int32')
paddle.incubate.identity_loss(x=x2, reduction=1)
self.assertRaises(TypeError, test_dtype)
paddle.disable_static()
class TestIdentityLossAPI(unittest.TestCase):
def setUp(self):
self.x_shape = [2, 3, 4, 5]
self.x = np.random.uniform(-1, 1, self.x_shape).astype(np.float32)
self.place = fluid.CPUPlace()
def identity_loss_ref(self, input, reduction):
if reduction == 0 or reduction == "sum":
return input.sum()
elif reduction == 1 or reduction == "mean":
return input.mean()
else:
return input
def test_api_static(self):
paddle.enable_static()
with paddle.static.program_guard(paddle.static.Program()):
x = paddle.fluid.data('X', self.x_shape)
out1 = paddle.incubate.identity_loss(x)
out2 = paddle.incubate.identity_loss(x, reduction=0)
out3 = paddle.incubate.identity_loss(x, reduction=1)
out4 = paddle.incubate.identity_loss(x, reduction=2)
exe = paddle.static.Executor(self.place)
res = exe.run(feed={'X': self.x},
fetch_list=[out1, out2, out3, out4])
ref = [
self.identity_loss_ref(self.x, 2),
self.identity_loss_ref(self.x, 0),
self.identity_loss_ref(self.x, 1),
self.identity_loss_ref(self.x, 2)
]
for out, out_ref in zip(res, ref):
self.assertEqual(np.allclose(out, out_ref, rtol=1e-04), True)
def test_api_dygraph(self):
paddle.disable_static(self.place)
def test_case(x, reduction):
x_tensor = paddle.to_tensor(x)
out = paddle.incubate.identity_loss(x_tensor, reduction)
out_ref = self.identity_loss_ref(x, reduction)
self.assertEqual(np.allclose(out.numpy(), out_ref, rtol=1e-04),
True)
test_case(self.x, 0)
test_case(self.x, 1)
test_case(self.x, 2)
test_case(self.x, "sum")
test_case(self.x, "mean")
test_case(self.x, "none")
paddle.enable_static()
def test_errors(self):
paddle.disable_static()
x = np.random.uniform(-1, 1, [10, 12]).astype('float32')
x = paddle.to_tensor(x)
self.assertRaises(Exception, paddle.incubate.identity_loss, x, -1)
self.assertRaises(Exception, paddle.incubate.identity_loss, x, 3)
self.assertRaises(Exception, paddle.incubate.identity_loss, x,
"wrongkey")
paddle.enable_static()
with paddle.static.program_guard(paddle.static.Program()):
x = paddle.fluid.data('X', [10, 12], 'int32')
self.assertRaises(TypeError, paddle.incubate.identity_loss, x)
if __name__ == '__main__':
unittest.main()
......@@ -35,6 +35,8 @@ from . import sparse #noqa: F401
from . import nn #noqa: F401
from . import asp #noqa: F401
from ..fluid.layers.loss import identity_loss
from ..fluid.incubate import fleet
__all__ = [
......@@ -50,4 +52,5 @@ __all__ = [
'segment_mean',
'segment_max',
'segment_min',
'identity_loss',
]