Unverified commit f43b5fe5, authored by jameszhang, committed by GitHub

kunlun support c_softmax_with_cross_entropy (#49934)

* kunlun support c_softmax_with_cross_entropy

* fix grad calc error

* replace mutable_data() and ShareDataWith()

* update xdnn

* update xpu toolchain to 20230215

* remove fluid from test file
Parent 605242a8
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/collective/c_softmax_with_cross_entropy_op.h"
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/device/xpu/bkcl_helper.h"
#include "paddle/fluid/string/string_helper.h"
#include "paddle/phi/backends/xpu/xpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/axis_utils.h"
#include "paddle/phi/kernels/funcs/cross_entropy.h"
#include "paddle/phi/kernels/funcs/softmax_impl.h"
#include "paddle/phi/kernels/xpu/elementwise.h"
#include "paddle/phi/kernels/xpu/reduce.h"
namespace paddle {
namespace operators {
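// c_softmax_with_cross_entropy computes softmax cross entropy when the class
// dimension is sharded across nranks devices: each rank holds D = C / nranks
// logit columns of the same N rows. For row i with global label y_i,
//   loss_i     = log(sum_j exp(logit_ij - max_i)) - (logit_{i,y_i} - max_i)
//   softmax_ij = exp(logit_ij - max_i) / sum_j exp(logit_ij - max_i)
// where max_i and the sum run over all C classes, so they (and the target
// logit) are combined across ranks with all-reduce (MAX / SUM).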
template <typename DeviceContext, typename T>
class CSoftmaxWithCrossEntropyOp : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
const int rid = ctx.Attr<int>("ring_id");
auto map = distributed::ProcessGroupMapFromGid::getInstance();
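// Dispatch on whether a ProcessGroup is registered for this ring_id: if so,
// use the ProcessGroup-based functor; otherwise fall back to the functor that
// relies on the BKCL communicator from BKCLCommContext.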
if (map->has(rid)) {
CSoftmaxWithCrossEntropyProcessGroupFunctor<DeviceContext, T> functor_;
functor_(ctx);
} else {
CSoftmaxWithCrossEntropyFunctor<DeviceContext, T> functor_;
functor_(ctx);
}
}
};
template <typename T>
struct CSoftmaxWithCrossEntropyProcessGroupFunctor<phi::XPUContext, T> {
void operator()(const framework::ExecutionContext& ctx) {
using XPUType = typename XPUTypeTrait<T>::Type;
const phi::DenseTensor* logits = ctx.Input<phi::DenseTensor>("Logits");
const phi::DenseTensor* labels = ctx.Input<phi::DenseTensor>("Label");
phi::DenseTensor* softmax = ctx.Output<phi::DenseTensor>("Softmax");
phi::DenseTensor* loss = ctx.Output<phi::DenseTensor>("Loss");
const int rid = ctx.Attr<int>("ring_id");
const int nranks = ctx.Attr<int>("nranks");
const int rank = ctx.Attr<int>("rank");
auto& dev_ctx = ctx.template device_context<phi::XPUContext>();
auto map = distributed::ProcessGroupMapFromGid::getInstance();
distributed::ProcessGroup* pg = map->get(rid);
distributed::AllreduceOptions opts;
opts.reduce_op = distributed::ReduceOp::MAX;
// allocate memory on device.
dev_ctx.template Alloc(softmax, logits->dtype());
dev_ctx.template Alloc(loss, logits->dtype());
const auto& logits_dims = logits->dims();
const int axis = logits_dims.size() - 1;
const int N = phi::funcs::SizeToAxis(axis, logits_dims);
const int D = phi::funcs::SizeFromAxis(axis, logits_dims);
phi::DenseTensor logits_2d, softmax_2d;
framework::TensorCopy(
*logits, ctx.GetPlace(), ctx.device_context(), &logits_2d);
framework::TensorCopy(
*softmax, ctx.GetPlace(), ctx.device_context(), &softmax_2d);
logits_2d.Resize({N, D});
softmax_2d.Resize({N, D});
int ret = -1;
// step 1, obtain logit_max
phi::DenseTensor logits_max;
logits_max = ctx.AllocateTmpTensor<T, phi::XPUContext>({N, 1}, dev_ctx);
{
// reduce last dim
int dims[1] = {1};
auto f = [](xpu::Context* ctx,
const XPUType* x,
XPUType* y,
const std::vector<int>& xdims,
const std::vector<int>& reduce_dims) {
return xpu::reduce_max<XPUType>(ctx, x, y, xdims, reduce_dims);
};
ret = phi::XPUReduce<phi::XPUContext, XPUType>(
dev_ctx,
logits_2d,
std::vector<int64_t>(dims, dims + 1),
false,
false,
&logits_max,
f);
PADDLE_ENFORCE_XDNN_SUCCESS(ret, "reduce_max");
}
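// all-reduce with MAX so that logits_max holds, for every row, the maximum
// logit over all class shards; it is used below for a numerically stable
// softmax.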
std::vector<phi::DenseTensor> in_out;
in_out.push_back(logits_max);
pg->AllReduce(in_out, in_out, opts)->Synchronize();
// step 2, obtain logit - logit_max
{
auto f = [](xpu::Context* ctx,
const XPUType* x,
const XPUType* y,
XPUType* z,
const std::vector<int>& xshape,
const std::vector<int>& yshape) {
return xpu::broadcast_sub<XPUType>(ctx, x, y, z, xshape, yshape);
};
phi::XPUElementwise<T, XPUType>(
dev_ctx, logits_2d, logits_max, axis, &softmax_2d, f);
}
// step 3, obtain the predicted logit of the target label
phi::DenseTensor predicted_logits;
predicted_logits =
ctx.AllocateTmpTensor<T, phi::XPUContext>({N, 1}, dev_ctx);
const int start_index = rank * D;
const int end_index = start_index + D;
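// mask_label_by_index picks the shifted logit of the target class for rows
// whose label falls into this rank's shard [start_index, end_index); the
// following SUM all-reduce makes that value visible on every rank.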
const auto& label_type = framework::TransToProtoVarType(labels->dtype());
if (label_type == framework::proto::VarType::INT32) {
ret = xpu::mask_label_by_index<XPUType, int32_t>(
dev_ctx.x_context(),
reinterpret_cast<const XPUType*>(softmax_2d.data<T>()),
labels->data<int32_t>(),
reinterpret_cast<XPUType*>(predicted_logits.data<T>()),
start_index,
end_index,
N,
D,
nranks);
} else if (label_type == framework::proto::VarType::INT64) {
ret = xpu::mask_label_by_index<XPUType, int64_t>(
dev_ctx.x_context(),
reinterpret_cast<const XPUType*>(softmax_2d.data<T>()),
labels->data<int64_t>(),
reinterpret_cast<XPUType*>(predicted_logits.data<T>()),
start_index,
end_index,
N,
D,
nranks);
}
PADDLE_ENFORCE_XDNN_SUCCESS(ret, "mask_label_by_index");
in_out.clear();
in_out.push_back(predicted_logits);
opts.reduce_op = distributed::ReduceOp::SUM;
pg->AllReduce(in_out, in_out, opts)->Synchronize();
// step 4, obtain exp(logit)
ret = xpu::exp<XPUType>(
dev_ctx.x_context(),
reinterpret_cast<const XPUType*>(softmax_2d.data<T>()),
reinterpret_cast<XPUType*>(softmax_2d.data<T>()),
N * D);
PADDLE_ENFORCE_XDNN_SUCCESS(ret, "exp");
// step 5, obtain sum_exp_logits
phi::DenseTensor sum_exp_logits;
sum_exp_logits = ctx.AllocateTmpTensor<T, phi::XPUContext>({N, 1}, dev_ctx);
{
int dims[1] = {1};
auto f = [](xpu::Context* ctx,
const XPUType* x,
XPUType* y,
const std::vector<int>& xdims,
const std::vector<int>& reduce_dims) {
return xpu::reduce_sum<XPUType>(ctx, x, y, xdims, reduce_dims);
};
ret = phi::XPUReduce<phi::XPUContext, XPUType>(
dev_ctx,
softmax_2d,
std::vector<int64_t>(dims, dims + 1),
false,
false,
&sum_exp_logits,
f);
PADDLE_ENFORCE_XDNN_SUCCESS(ret, "reduce_sum");
}
in_out.clear();
in_out.push_back(sum_exp_logits);
opts.reduce_op = distributed::ReduceOp::SUM;
pg->AllReduce(in_out, in_out, opts)->Synchronize();
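// step 6, normalize to get softmax and compute the loss:
//   softmax = exp(logit - logit_max) / sum_exp_logits
//   loss    = log(sum_exp_logits) - predicted_logits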
int dims[4] = {N, D, N, 1};
ret = xpu::broadcast_div<XPUType>(
dev_ctx.x_context(),
reinterpret_cast<const XPUType*>(softmax_2d.data<T>()),
reinterpret_cast<const XPUType*>(sum_exp_logits.data<T>()),
reinterpret_cast<XPUType*>(softmax_2d.data<T>()),
std::vector<int64_t>(dims, dims + 2),
std::vector<int64_t>(dims + 2, dims + 4));
PADDLE_ENFORCE_XDNN_SUCCESS(ret, "broadcast_div");
ret = xpu::log<XPUType>(
dev_ctx.x_context(),
reinterpret_cast<const XPUType*>(sum_exp_logits.data<T>()),
reinterpret_cast<XPUType*>(sum_exp_logits.data<T>()),
N * 1);
PADDLE_ENFORCE_XDNN_SUCCESS(ret, "log");
ret = xpu::sub<XPUType>(
dev_ctx.x_context(),
reinterpret_cast<const XPUType*>(sum_exp_logits.data<T>()),
reinterpret_cast<const XPUType*>(predicted_logits.data<T>()),
reinterpret_cast<XPUType*>(loss->data<T>()),
N * 1);
PADDLE_ENFORCE_XDNN_SUCCESS(ret, "sub");
framework::TensorCopy(
softmax_2d, ctx.GetPlace(), ctx.device_context(), softmax);
}
};
template <typename T>
struct CSoftmaxWithCrossEntropyFunctor<phi::XPUContext, T> {
void operator()(const framework::ExecutionContext& ctx) {
using XPUType = typename XPUTypeTrait<T>::Type;
const phi::DenseTensor* logits = ctx.Input<phi::DenseTensor>("Logits");
const phi::DenseTensor* labels = ctx.Input<phi::DenseTensor>("Label");
phi::DenseTensor* softmax = ctx.Output<phi::DenseTensor>("Softmax");
phi::DenseTensor* loss = ctx.Output<phi::DenseTensor>("Loss");
const int rid = ctx.Attr<int>("ring_id");
const int nranks = ctx.Attr<int>("nranks");
const int rank = ctx.Attr<int>("rank");
const auto& place = ctx.GetPlace();
const auto& comm = platform::BKCLCommContext::Instance().Get(rid, place);
auto& dev_ctx = ctx.template device_context<phi::XPUContext>();
// use the global calculation stream
const auto stream = static_cast<phi::XPUContext*>(
platform::DeviceContextPool::Instance().Get(place))
->stream();
// allocate memory on device.
dev_ctx.template Alloc(softmax, logits->dtype());
dev_ctx.template Alloc(loss, logits->dtype());
const auto& logits_dims = logits->dims();
const int axis = logits_dims.size() - 1;
const int N = phi::funcs::SizeToAxis(axis, logits_dims);
const int D = phi::funcs::SizeFromAxis(axis, logits_dims);
phi::DenseTensor logits_2d, softmax_2d;
framework::TensorCopy(
*logits, ctx.GetPlace(), ctx.device_context(), &logits_2d);
framework::TensorCopy(
*softmax, ctx.GetPlace(), ctx.device_context(), &softmax_2d);
logits_2d.Resize({N, D});
softmax_2d.Resize({N, D});
int ret = -1;
// step 1, obtain logit_max
phi::DenseTensor logits_max;
logits_max = ctx.AllocateTmpTensor<T, phi::XPUContext>({N, 1}, dev_ctx);
void* logits_max_buff = logits_max.data<T>();
{
int dims[1] = {1};
auto f = [](xpu::Context* ctx,
const XPUType* x,
XPUType* y,
const std::vector<int>& xdims,
const std::vector<int>& reduce_dims) {
return xpu::reduce_max<XPUType>(ctx, x, y, xdims, reduce_dims);
};
ret = phi::XPUReduce<phi::XPUContext, XPUType>(
dev_ctx,
logits_2d,
std::vector<int64_t>(dims, dims + 1),
false,
false,
&logits_max,
f);
PADDLE_ENFORCE_XDNN_SUCCESS(ret, "reduce_max");
}
PADDLE_ENFORCE_XPU_SUCCESS(
bkcl_all_reduce(comm->comm(),
logits_max_buff,
logits_max_buff,
logits_max.numel(),
platform::ToBKCLDataType(
framework::TransToProtoVarType(logits_max.dtype())),
BKCL_MAX,
stream));
xpu_wait(stream);
// step 2, obtain logit - logit_max
{
auto f = [](xpu::Context* ctx,
const XPUType* x,
const XPUType* y,
XPUType* z,
const std::vector<int>& xshape,
const std::vector<int>& yshape) {
return xpu::broadcast_sub<XPUType>(ctx, x, y, z, xshape, yshape);
};
phi::XPUElementwise<T, XPUType>(
dev_ctx, logits_2d, logits_max, axis, &softmax_2d, f);
}
// step 3, obtain the predicted logit of the target label
phi::DenseTensor predicted_logits;
predicted_logits =
ctx.AllocateTmpTensor<T, phi::XPUContext>({N, 1}, dev_ctx);
void* predict_logits_buff = predicted_logits.data<T>();
ret = xpu::constant<XPUType>(
dev_ctx.x_context(),
reinterpret_cast<XPUType*>(predicted_logits.data<T>()),
N,
0.0);
PADDLE_ENFORCE_XDNN_SUCCESS(ret, "constant");
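// predicted_logits is zero-filled so that ranks whose shard does not contain
// a row's label contribute zero to the SUM all-reduce below.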
const int start_index = rank * D;
const int end_index = start_index + D;
const auto& label_type = framework::TransToProtoVarType(labels->dtype());
if (label_type == framework::proto::VarType::INT32) {
ret = xpu::mask_label_by_index<XPUType, int32_t>(
dev_ctx.x_context(),
reinterpret_cast<const XPUType*>(softmax_2d.data<T>()),
labels->data<int32_t>(),
reinterpret_cast<XPUType*>(predicted_logits.data<T>()),
start_index,
end_index,
N,
D,
nranks);
} else if (label_type == framework::proto::VarType::INT64) {
ret = xpu::mask_label_by_index<XPUType, int64_t>(
dev_ctx.x_context(),
reinterpret_cast<const XPUType*>(softmax_2d.data<T>()),
labels->data<int64_t>(),
reinterpret_cast<XPUType*>(predicted_logits.data<T>()),
start_index,
end_index,
N,
D,
nranks);
}
PADDLE_ENFORCE_XDNN_SUCCESS(ret, "mask_label_by_index");
PADDLE_ENFORCE_XPU_SUCCESS(bkcl_all_reduce(
comm->comm(),
predict_logits_buff,
predict_logits_buff,
predicted_logits.numel(),
platform::ToBKCLDataType(
framework::TransToProtoVarType(predicted_logits.dtype())),
BKCL_ADD,
stream));
xpu_wait(stream);
// step 4, obtain exp(logit)
ret = xpu::exp<XPUType>(
dev_ctx.x_context(),
reinterpret_cast<const XPUType*>(softmax_2d.data<T>()),
reinterpret_cast<XPUType*>(softmax_2d.data<T>()),
N * D);
PADDLE_ENFORCE_XDNN_SUCCESS(ret, "exp");
// step 5, obtain sum_exp_logits
phi::DenseTensor sum_exp_logits;
sum_exp_logits = ctx.AllocateTmpTensor<T, phi::XPUContext>({N, 1}, dev_ctx);
{
int dims[1] = {1};
auto f = [](xpu::Context* ctx,
const XPUType* x,
XPUType* y,
const std::vector<int>& xdims,
const std::vector<int>& reduce_dims) {
return xpu::reduce_sum<XPUType>(ctx, x, y, xdims, reduce_dims);
};
ret = phi::XPUReduce<phi::XPUContext, XPUType>(
dev_ctx,
softmax_2d,
std::vector<int64_t>(dims, dims + 1),
false,
false,
&sum_exp_logits,
f);
PADDLE_ENFORCE_XDNN_SUCCESS(ret, "reduce_sum");
}
void* sum_exp_logits_buff = sum_exp_logits.data<T>();
PADDLE_ENFORCE_XPU_SUCCESS(bkcl_all_reduce(
comm->comm(),
sum_exp_logits_buff,
sum_exp_logits_buff,
sum_exp_logits.numel(),
platform::ToBKCLDataType(
framework::TransToProtoVarType(sum_exp_logits.dtype())),
BKCL_ADD,
stream));
xpu_wait(stream);
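// step 6, normalize to get softmax and compute the loss:
//   softmax = exp(logit - logit_max) / sum_exp_logits
//   loss    = log(sum_exp_logits) - predicted_logits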
{
int dims[4] = {N, D, N, 1};
ret = xpu::broadcast_div<XPUType>(
dev_ctx.x_context(),
reinterpret_cast<const XPUType*>(softmax_2d.data<T>()),
reinterpret_cast<const XPUType*>(sum_exp_logits.data<T>()),
reinterpret_cast<XPUType*>(softmax_2d.data<T>()),
std::vector<int64_t>(dims, dims + 2),
std::vector<int64_t>(dims + 2, dims + 4));
PADDLE_ENFORCE_XDNN_SUCCESS(ret, "broadcast_div");
}
ret = xpu::log<XPUType>(
dev_ctx.x_context(),
reinterpret_cast<const XPUType*>(sum_exp_logits.data<T>()),
reinterpret_cast<XPUType*>(sum_exp_logits.data<T>()),
N * 1);
PADDLE_ENFORCE_XDNN_SUCCESS(ret, "log");
ret = xpu::sub<XPUType>(
dev_ctx.x_context(),
reinterpret_cast<const XPUType*>(sum_exp_logits.data<T>()),
reinterpret_cast<const XPUType*>(predicted_logits.data<T>()),
reinterpret_cast<XPUType*>(loss->data<T>()),
N * 1);
PADDLE_ENFORCE_XDNN_SUCCESS(ret, "sub");
framework::TensorCopy(
softmax_2d, ctx.GetPlace(), ctx.device_context(), softmax);
}
};
template <typename DeviceContext, typename T>
class CSoftmaxWithCrossEntropyGrad : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
using XPUType = typename XPUTypeTrait<T>::Type;
const phi::DenseTensor* labels = context.Input<phi::DenseTensor>("Label");
const phi::DenseTensor* loss_grad =
context.Input<phi::DenseTensor>(framework::GradVarName("Loss"));
phi::DenseTensor* logit_grad =
context.Output<phi::DenseTensor>(framework::GradVarName("Logits"));
const phi::DenseTensor* softmax =
context.Input<phi::DenseTensor>("Softmax");
const int rank = context.Attr<int>("rank");
auto& dev_ctx = context.template device_context<DeviceContext>();
if (logit_grad != softmax) {
framework::TensorCopy(
*softmax, context.GetPlace(), context.device_context(), logit_grad);
}
const auto softmax_dims = softmax->dims();
const int axis = softmax_dims.size() - 1;
const int N = phi::funcs::SizeToAxis(axis, softmax_dims);
const int D = phi::funcs::SizeFromAxis(axis, softmax_dims);
const int start_index = rank * D;
const int end_index = start_index + D;
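// Backward pass: dLogits = (softmax - one_hot(label)) * dLoss. logit_grad was
// initialized with softmax above; mask_label_by_index_grad applies the loss
// gradient and the "-1" correction for the target class, restricted to rows
// whose label lies in this rank's shard [start_index, end_index).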
const auto& label_type = framework::TransToProtoVarType(labels->dtype());
int ret = 0;
if (label_type == framework::proto::VarType::INT32) {
ret = xpu::mask_label_by_index_grad<XPUType, int32_t>(
dev_ctx.x_context(),
reinterpret_cast<const XPUType*>(loss_grad->data<T>()),
labels->data<int32_t>(),
reinterpret_cast<XPUType*>(logit_grad->data<T>()),
start_index,
end_index,
N,
D);
} else if (label_type == framework::proto::VarType::INT64) {
ret = xpu::mask_label_by_index_grad<XPUType, int64_t>(
dev_ctx.x_context(),
reinterpret_cast<const XPUType*>(loss_grad->data<T>()),
labels->data<int64_t>(),
reinterpret_cast<XPUType*>(logit_grad->data<T>()),
start_index,
end_index,
N,
D);
}
PADDLE_ENFORCE_XDNN_SUCCESS(ret, "mask_label_by_index_grad");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_XPU_KERNEL(c_softmax_with_cross_entropy,
ops::CSoftmaxWithCrossEntropyOp<phi::XPUContext, float>);
REGISTER_OP_XPU_KERNEL(
c_softmax_with_cross_entropy_grad,
ops::CSoftmaxWithCrossEntropyGrad<phi::XPUContext, float>);
@@ -99,6 +99,9 @@ XPUOpMap& get_kl2_ops() {
phi::DataType::FLOAT64,
phi::DataType::INT32,
phi::DataType::INT64})},
{"c_softmax_with_cross_entropy", XPUKernelSet({phi::DataType::FLOAT32})},
{"c_softmax_with_cross_entropy_grad",
XPUKernelSet({phi::DataType::FLOAT32})},
{"c_reduce_sum",
XPUKernelSet({phi::DataType::FLOAT16, phi::DataType::FLOAT32})},
{"c_split",
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import pickle
import sys
import numpy as np
from test_collective_base_xpu import (
DataTypeCast,
TestCollectiveRunnerBase,
runtime_main,
)
import paddle
from paddle.framework import core
from paddle.static import Executor, Program, data, program_guard
paddle.enable_static()
class TestCollectiveSoftmaxWithCE(TestCollectiveRunnerBase):
def __init__(self):
self.global_ring_id = 0
self.batch_size = 10
self.num_class = 1000
self.nranks = 2
self.ring_id = 0
self.local_elements = int(self.num_class / self.nranks)
def get_model(self, main_prog, startup_program, rank):
with program_guard(main_prog, startup_program):
logits = data(
name="Logits",
shape=[self.batch_size, self.local_elements],
dtype='float32',
)
label = data(
name="Label", shape=[self.batch_size, 1], dtype='int32'
)
softmax = main_prog.current_block().create_var(
name="Softmax",
dtype=logits.dtype,
type=core.VarDesc.VarType.LOD_TENSOR,
persistable=False,
stop_gradient=False,
)
loss = main_prog.current_block().create_var(
name="Loss",
dtype=logits.dtype,
type=core.VarDesc.VarType.LOD_TENSOR,
persistable=False,
stop_gradient=False,
)
loss_grad = main_prog.current_block().create_var(
name="Loss@GRAD",
shape=[self.batch_size, 1],
dtype=logits.dtype,
type=core.VarDesc.VarType.LOD_TENSOR,
persistable=False,
stop_gradient=False,
)
block = main_prog.global_block()
with paddle.static.device_guard("xpu"):
c_softmax_with_ce_op = block.append_op(
type="c_softmax_with_cross_entropy",
inputs={'Logits': logits, 'Label': label},
outputs={'Softmax': softmax, 'Loss': loss},
attrs={
'ring_id': self.ring_id,
'rank': rank,
'nranks': self.nranks,
},
)
# generate backward op_desc
grad_op_desc_list, op_grad_to_var = core.get_grad_op_desc(
c_softmax_with_ce_op.desc, set(), []
)
for grad_op_desc in grad_op_desc_list:
new_op_desc = block.desc.append_op()
new_op_desc.copy_from(grad_op_desc)
for var_name in grad_op_desc.output_arg_names():
block.desc.var(var_name.encode("ascii"))
grad_op_desc.infer_var_type(block.desc)
grad_op_desc.infer_shape(block.desc)
for arg in grad_op_desc.output_arg_names():
grad_var = block.desc.find_var(arg.encode("ascii"))
grad_var.set_dtype(core.VarDesc.VarType.FP32)
main_prog._sync_with_cpp()
return loss, softmax
def run_trainer(self, args):
train_prog = Program()
startup_prog = Program()
endpoints = args["endpoints"].split(",")
rank = args["trainerid"]
current_endpoint = args["currentendpoint"]
self.initCommunicator(
startup_prog, rank, self.nranks, True, current_endpoint, endpoints
)
np_data_type = DataTypeCast(args["data_type"])
loss, softmax = self.get_model(train_prog, startup_prog, rank)
device_id = int(os.getenv("FLAGS_selected_xpus", "0"))
place = paddle.XPUPlace(device_id)
exe = Executor(place)
exe.run(startup_prog)
# NOTE: use uid here to ensure that the two xpus share the same label
np.random.seed(os.getuid())
label = np.random.randint(
0,
self.num_class,
size=(self.batch_size, 1),
dtype='int32',
)
# use a FAKE loss_grad here, only to check the correctness of the grad function
loss_grad = np.random.uniform(
low=-10.0, high=10.0, size=(self.batch_size, 1)
).astype(np_data_type)
# each xpu uses its own half of the logits
np.random.seed(os.getpid())
logits = np.random.uniform(
low=-10.0, high=10.0, size=(self.batch_size, self.local_elements)
).astype(np_data_type)
out = exe.run(
train_prog,
feed={'Logits': logits, 'Label': label, 'Loss@GRAD': loss_grad},
fetch_list=[loss.name, softmax.name, 'Logits@GRAD'],
)
sys.stdout.buffer.write(pickle.dumps(out))
if __name__ == "__main__":
os.environ["BKCL_PCIE_RING"] = "1"
runtime_main(TestCollectiveSoftmaxWithCE, "softmax_with_ce", 0)
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import unittest
import numpy as np
from test_collective_base_xpu import DataTypeCast, TestDistBase
import paddle
from paddle.framework import core
sys.path.append("..")
from xpu.get_test_cover_info import (
XPUOpTestWrapper,
create_test_class,
get_xpu_op_support_types,
)
paddle.enable_static()
def stable_softmax(x):
"""Compute the softmax of vector x in a numerically stable way."""
# clip shiftx to avoid underflow; otherwise, when the loss is computed as
# log(exp(shiftx)), we may get log(0) = -inf
shiftx = (x - np.max(x)).clip(-64.0)
exps = np.exp(shiftx)
return exps / np.sum(exps)
def cross_entropy(softmax, label, soft_label, axis, ignore_index=-1):
if soft_label:
return (-label * np.log(softmax)).sum(axis=axis, keepdims=True)
shape = softmax.shape
axis %= len(shape)
n = int(np.prod(shape[:axis]))
axis_dim = shape[axis]
remain = int(np.prod(shape[axis + 1 :]))
softmax_reshape = softmax.reshape((n, axis_dim, remain))
label_reshape = label.reshape((n, 1, remain))
result = np.zeros_like(label_reshape, dtype=softmax.dtype)
for i in range(n):
for j in range(remain):
lbl = label_reshape[i, 0, j]
if lbl != ignore_index:
result[i, 0, j] -= np.log(softmax_reshape[i, lbl, j])
return result.reshape(label.shape)
def softmax_with_cross_entropy_grad(softmax, label, loss_grad, axis):
logit_grad = softmax.copy()
shape = softmax.shape
axis %= len(shape)
n = int(np.prod(shape[:axis]))
d = int(np.prod(shape[axis:]))
for i in range(n * d):
row = int(i / d)
col = i % d
if col == label[row]:
logit_grad[row][col] = (logit_grad[row][col] - 1.0) * loss_grad[row]
else:
logit_grad[row][col] = logit_grad[row][col] * loss_grad[row]
return logit_grad
class XPUTestCSoftmaxWithCEOP(XPUOpTestWrapper):
def __init__(self):
self.op_name = 'c_softmax_with_cross_entropy'
self.use_dynamic_create_class = False
class TestCSoftmaxWithCEOp(TestDistBase):
def _setup_config(self):
pass
def test_softmax_with_ce(self):
self.batch_size = 10
self.num_class = 1000
self.check_with_place(
"collective_softmax_with_cross_entropy_op_xpu.py",
"softmax_with_ce",
self.in_type_str,
)
def check_with_place(
self,
model_file,
col_type,
data_type,
check_error_log=False,
need_envs={},
):
required_envs = {
"FLAGS_eager_delete_tensor_gb": "0.0",
"PATH": os.getenv("PATH"),
"PYTHONPATH": os.getenv("PYTHONPATH", ""),
"LD_LIBRARY_PATH": os.getenv("LD_LIBRARY_PATH", ""),
"LD_PRELOAD": os.getenv("LD_PRELOAD", ""),
"GLOG_v": "0",
"DATA_TYPE": data_type,
}
required_envs.update(need_envs)
if check_error_log:
required_envs["GLOG_v"] = "3"
required_envs["GLOG_logtostderr"] = "1"
np_data_type = DataTypeCast(data_type)
tr0_out, tr1_out, pid0, pid1 = self._run_cluster(
model_file, required_envs
)
# get data that is shared by both ranks
np.random.seed(os.getuid())
label = np.random.randint(
0, self.num_class, size=(self.batch_size, 1), dtype='int32'
)
loss_grad = np.random.uniform(
low=-10.0, high=10.0, size=(self.batch_size, 1)
).astype(np_data_type)
local_elements = int(self.num_class / 2)
# get input data for rank 0
np.random.seed(pid0)
input0 = np.random.uniform(
low=-10.0, high=10.0, size=(self.batch_size, local_elements)
).astype(np_data_type)
# get input data for rank 1
np.random.seed(pid1)
input1 = np.random.uniform(
low=-10.0, high=10.0, size=(self.batch_size, local_elements)
).astype(np_data_type)
# get combined input data
inputs = np.concatenate((input0, input1), axis=1)
# calculate analytic result
need_softmax = np.apply_along_axis(stable_softmax, 1, inputs)
need_loss = cross_entropy(need_softmax, label, False, 1)
need_logits_grad = softmax_with_cross_entropy_grad(
need_softmax, label, loss_grad, axis=1
)
# get real result
loss0, softmax0, logits_grad0 = tr0_out
loss1, softmax1, logits_grad1 = tr1_out
softmax = np.concatenate((softmax0, softmax1), axis=1)
logits_grad = np.concatenate((logits_grad0, logits_grad1), axis=1)
# compare results
rtol = 1e-6
np.testing.assert_allclose(loss0, need_loss, rtol=rtol)
np.testing.assert_allclose(loss1, need_loss, rtol=rtol)
np.testing.assert_allclose(softmax, need_softmax, rtol=rtol)
np.testing.assert_allclose(logits_grad, need_logits_grad, rtol=rtol)
support_types = get_xpu_op_support_types('c_softmax_with_cross_entropy')
for stype in support_types:
create_test_class(
globals(),
XPUTestCSoftmaxWithCEOP,
stype,
ignore_device_version=[core.XPUVersion.XPU1],
)
if __name__ == '__main__':
unittest.main()