未验证 提交 902c6f98 编写于 作者: S ShenLiang 提交者: GitHub

[HybridParallel]Fix c_split op for TensorParallel (#33207)

* fix c_split bug

* fix utest

* add c_embedding for tensorparallel
上级 2c10ca64
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/collective/c_embedding_op.h"
namespace paddle {
namespace operators {
class CEmbeddingOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("W"), "Input", "W", "CEmbeddingOp");
OP_INOUT_CHECK(ctx->HasInput("Ids"), "Input", "Ids", "CEmbeddingOp");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "CEmbeddingOp");
auto table_dims = ctx->GetInputDim("W");
auto ids_dims = ctx->GetInputDim("Ids");
int ids_rank = ids_dims.size();
VLOG(5) << "ids rank is " << ids_rank << std::endl;
PADDLE_ENFORCE_EQ(
table_dims.size(), 2,
platform::errors::InvalidArgument(
"ShapeError: The dimensions of the 'c_embedding' must be 2. "
"But received c_embedding's dimensions = %d, "
"c_embedding's shape = [%s].",
table_dims.size(), table_dims));
auto output_dims = framework::vectorize(ids_dims);
output_dims.push_back(table_dims[1]);
ctx->SetOutputDim("Out", framework::make_ddim(output_dims));
if (ctx->GetOutputsVarType("Out")[0] ==
framework::proto::VarType::LOD_TENSOR) {
ctx->ShareLoD("Ids", /*->*/ "Out");
}
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "W");
return framework::OpKernelType(data_type, ctx.device_context());
}
};
class CEmbeddingOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("W",
"(Tensor) The input represents embedding tensors, "
"which is a learnable parameter.");
AddInput("Ids",
"An input with type int64 "
"contains the ids to be looked up in W.");
AddOutput("Out", "The lookup results, which have the same type as W.");
AddAttr<int64_t>("start_index",
"(int64, default 0), The starting index is indeed, "
"and the out-of-bounds will be set to 0 ")
.SetDefault(0);
AddComment(R"DOC(
c_embedding Operator.
This operator is used to perform lookups on the parameter W,
then concatenated into a dense tensor.
The input Ids can carry the LoD (Level of Details) information,
or not. And the output only shares the LoD information with input Ids.
)DOC");
}
};
DECLARE_NO_NEED_BUFFER_VARS_INFERER(CEmbeddingGradOpNoBufferVarsInferer, "W");
template <typename T>
class CEmbeddingGradOpMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
void Apply(GradOpPtr<T> op) const override {
op->SetType("c_embedding_grad");
op->SetInput("W", this->Input("W"));
op->SetInput("Ids", this->Input("Ids"));
op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
op->SetOutput(framework::GradVarName("W"), this->InputGrad("W"));
op->SetAttrMap(this->Attrs());
}
};
class CEmbeddingOpGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
auto table_dims = ctx->GetInputDim("W");
ctx->SetOutputDim(framework::GradVarName("W"), table_dims);
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
auto data_type = OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out"));
return framework::OpKernelType(data_type, ctx.device_context());
}
};
class CEmbeddingOpGradVarTypeInference : public framework::VarTypeInference {
public:
void operator()(framework::InferVarTypeContext* ctx) const override {
auto out_var_name = framework::GradVarName("W");
VLOG(3) << "c_embedding_grad op " << framework::GradVarName("W")
<< " is set to LoDTensor";
ctx->SetOutputType(out_var_name, framework::proto::VarType::LOD_TENSOR);
ctx->SetOutputDataType(out_var_name, ctx->GetInputDataType("W"));
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(c_embedding, ops::CEmbeddingOp, ops::CEmbeddingOpMaker,
ops::CEmbeddingGradOpMaker<paddle::framework::OpDesc>,
ops::CEmbeddingGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(c_embedding_grad, ops::CEmbeddingOpGrad,
ops::CEmbeddingGradOpNoBufferVarsInferer,
ops::CEmbeddingOpGradVarTypeInference);
REGISTER_OP_CPU_KERNEL(c_embedding, ops::CEmbeddingOpCPUKernel<float>,
ops::CEmbeddingOpCPUKernel<double>);
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/collective/c_embedding_op.h"
#include "paddle/fluid/platform/cuda_primitives.h"
#include "paddle/fluid/platform/float16.h"
namespace paddle {
namespace operators {
static constexpr int kNumCUDAThreads = 512;
static constexpr int kNumMaxinumNumBlocks = 4096;
static inline int NumBlocks(const int N) {
return std::min((N + kNumCUDAThreads - 1) / kNumCUDAThreads,
kNumMaxinumNumBlocks);
}
template <typename T, typename IndexT>
__global__ void CEmbedding(T *out, const T *table, const IndexT *ids,
const int rows, const int columns, const int64_t N,
const int64_t start_idx, const int64_t end_idx,
const int64_t limit) {
CUDA_KERNEL_LOOP(i, limit) {
size_t row = i / columns;
size_t col = i % columns;
auto id = ids[row];
if (id >= start_idx && id < end_idx) {
auto real_idx = id - start_idx;
PADDLE_ENFORCE(real_idx < N,
"The index is out of bounds, "
"please check whether the dimensions of index and "
"input meet the requirements. It should "
"be less than [%d], but received [%d]",
N, real_idx);
out[i] = table[real_idx * columns + col];
} else {
out[i] = static_cast<T>(0);
}
}
}
template <typename T, typename IndexT>
__global__ void CEmbeddingGrad(T *table, const T *output, const IndexT *ids,
const int rows, const int columns,
const int64_t N, const int64_t start_idx,
const int64_t end_idx, const int64_t limit) {
CUDA_KERNEL_LOOP(i, limit) {
size_t row = i / columns;
size_t col = i % columns;
auto id = ids[row];
if (id >= start_idx && id < end_idx) {
auto real_idx = id - start_idx;
paddle::platform::CudaAtomicAdd(&table[real_idx * columns + col],
output[i]);
}
}
}
template <typename T>
class CEmbeddingCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &context) const override {
auto *table_t = context.Input<LoDTensor>("W");
auto *ids_t = context.Input<LoDTensor>("Ids");
auto *output_t = context.Output<LoDTensor>("Out");
const auto &dev_ctx =
context.template device_context<platform::CUDADeviceContext>();
const int64_t start_idx = context.Attr<int64_t>("start_index");
size_t N = table_t->dims()[0];
size_t D = table_t->dims()[1];
size_t K = ids_t->numel();
const int64_t end_idx = start_idx + N;
auto *table = table_t->data<T>();
auto *output = output_t->mutable_data<T>(context.GetPlace());
auto limit = K * D;
int blocks = NumBlocks(limit);
int threads = kNumCUDAThreads;
const auto &index_type = ids_t->type();
if (index_type == framework::proto::VarType::INT32) {
CEmbedding<T, int32_t><<<blocks, threads, 0, dev_ctx.stream()>>>(
output, table, ids_t->data<int32_t>(), K, D, N, start_idx, end_idx,
limit);
} else if (index_type == framework::proto::VarType::INT64) {
CEmbedding<T, int64_t><<<blocks, threads, 0, dev_ctx.stream()>>>(
output, table, ids_t->data<int64_t>(), K, D, N, start_idx, end_idx,
limit);
}
}
};
template <typename T>
class CEmbeddingGradCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &context) const override {
const auto &dev_ctx =
context.template device_context<platform::CUDADeviceContext>();
const int64_t start_idx = context.Attr<int64_t>("start_index");
auto ids_t = context.Input<LoDTensor>("Ids");
auto d_output_t = context.Input<LoDTensor>(framework::GradVarName("Out"));
auto d_table_t = context.Output<LoDTensor>(framework::GradVarName("W"));
int N = d_table_t->dims()[0];
int D = d_table_t->dims()[1];
int K = ids_t->numel();
const int64_t end_idx = start_idx + N;
auto limit = K * D;
int blocks = NumBlocks(limit);
int threads = kNumCUDAThreads;
const T *d_output = d_output_t->data<T>();
T *d_table = d_table_t->mutable_data<T>(context.GetPlace());
auto t = framework::EigenVector<T>::Flatten(*d_table_t);
t.device(*dev_ctx.eigen_device()) = t.constant(static_cast<T>(0));
const auto &index_type = ids_t->type();
if (index_type == framework::proto::VarType::INT32) {
CEmbeddingGrad<T, int32_t><<<blocks, threads, 0, dev_ctx.stream()>>>(
d_table, d_output, ids_t->data<int32_t>(), K, D, N, start_idx,
end_idx, limit);
} else if (index_type == framework::proto::VarType::INT64) {
CEmbeddingGrad<T, int64_t><<<blocks, threads, 0, dev_ctx.stream()>>>(
d_table, d_output, ids_t->data<int64_t>(), K, D, N, start_idx,
end_idx, limit);
}
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(c_embedding, ops::CEmbeddingCUDAKernel<float>,
ops::CEmbeddingCUDAKernel<double>,
ops::CEmbeddingCUDAKernel<plat::float16>);
REGISTER_OP_CUDA_KERNEL(c_embedding_grad, ops::CEmbeddingGradCUDAKernel<float>,
ops::CEmbeddingGradCUDAKernel<double>,
ops::CEmbeddingGradCUDAKernel<plat::float16>);
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <algorithm>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
using LoDTensor = framework::LoDTensor;
template <typename T>
class CEmbeddingOpCPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_THROW(platform::errors::Unavailable(
"Do not support c_embedding for cpu kernel now."));
}
};
} // namespace operators
} // namespace paddle
......@@ -45,6 +45,12 @@ class CSplitOp : public framework::OperatorWithKernel {
rank, nranks));
framework::DDim dim = ctx->GetInputDim("X");
PADDLE_ENFORCE_EQ(
dim[dim.size() - 1] % nranks, 0,
platform::errors::InvalidArgument("The last dimension (%d) of the X "
"should be divisible by nranks (%d)",
dim[dim.size() - 1], nranks));
dim[dim.size() - 1] = dim[dim.size() - 1] / nranks;
if (dim[0] < 0) dim[0] = -1;
ctx->SetOutputDim("Out", dim);
......
......@@ -16,10 +16,38 @@ limitations under the License. */
#include "paddle/fluid/operators/collective/c_split_op.h"
#include "paddle/fluid/operators/math/concat_and_split.h"
#include "paddle/fluid/platform/cuda_primitives.h"
namespace paddle {
namespace operators {
static constexpr int kNumCUDAThreads = 512;
static constexpr int kNumMaxinumNumBlocks = 4096;
static inline int NumBlocks(const int N) {
return std::min((N + kNumCUDAThreads - 1) / kNumCUDAThreads,
kNumMaxinumNumBlocks);
}
template <typename T>
__global__ void SplitFromRank(const T* input, T* output, const int rows,
const int columns, const int rank,
const int nranks, const int limit) {
CUDA_KERNEL_LOOP(i, limit) {
int row = i / columns;
int col = i % columns;
int block = columns / nranks;
int start = block * rank;
int end = start + block;
if (col >= start && col < end) {
int idx = block * row + col % block;
output[idx] = input[i];
}
}
}
template <typename T>
class CSplitOpCUDAKernel : public framework::OpKernel<T> {
public:
......@@ -47,24 +75,25 @@ class CSplitOpCUDAKernel : public framework::OpKernel<T> {
rank, nranks));
auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
std::vector<const framework::Tensor*> shape_refer;
std::vector<framework::Tensor*> results;
size_t numel = x->numel();
auto dims = x->dims();
numel /= nranks;
int axis = dims.size() - 1;
dims[dims.size() - 1] /= nranks;
for (int i = 0; i < nranks; i++) {
framework::Tensor* out = new framework::Tensor();
out->mutable_data<T>(dims, place);
shape_refer.emplace_back(out);
results.emplace_back(out);
}
auto dims_size = dims.size();
// final dim
int64_t end_size = dims[dims_size - 1];
math::SplitFunctor<platform::CUDADeviceContext, T> functor;
functor(dev_ctx, *x, shape_refer, axis, &results);
// remain dim
auto remain_ddim = framework::slice_ddim(dims, 0, dims_size - 1);
int64_t remain_numel = framework::product(remain_ddim);
int limit = x->numel();
int blocks = NumBlocks(limit);
int threads = kNumCUDAThreads;
dims[dims_size - 1] /= nranks;
out->mutable_data<T>(dims, place);
paddle::framework::TensorCopySync(*results[rank], out->place(), out);
SplitFromRank<T><<<blocks, threads, 0, dev_ctx.stream()>>>(
x->data<T>(), out->data<T>(), remain_numel, end_size, rank, nranks,
limit);
}
};
} // namespace operators
......
......@@ -781,7 +781,7 @@ def _c_identity(tensor, group=None):
return out
def _c_concat(tensor, nranks, group=None):
def _c_concat(tensor, group=None):
"""
Return allgather of the tensor, mainly used with model parallel.
......@@ -797,10 +797,14 @@ def _c_concat(tensor, nranks, group=None):
return
ring_id = 0 if group is None else group.id
global_rank = _get_global_env().rank
rank = global_rank if group is None else group.get_group_rank(global_rank)
nranks = _get_global_env().world_size if group is None else group.nranks
if in_dygraph_mode():
return core.ops.c_concat(tensor, 'ring_id', ring_id, 'use_calc_stream',
True, 'nranks', nranks, 'use_model_parallel',
True)
True, 'rank', rank, 'nranks', nranks,
'use_model_parallel', True)
op_type = 'c_concat'
helper = LayerHelper(op_type, **locals())
......@@ -818,12 +822,13 @@ def _c_concat(tensor, nranks, group=None):
'ring_id': ring_id,
'use_calc_stream': True,
'use_model_parallel': True,
'nranks': nranks
'nranks': nranks,
'rank': rank
})
return out
def _c_split(tensor, rank, nranks, group=None):
def _c_split(tensor, group=None):
"""
Split tensor evenly among all members, mainly used with model parallel.
......@@ -840,6 +845,10 @@ def _c_split(tensor, rank, nranks, group=None):
return
ring_id = 0 if group is None else group.id
global_rank = _get_global_env().rank
rank = global_rank if group is None else group.get_group_rank(global_rank)
nranks = _get_global_env().world_size if group is None else group.nranks
if in_dygraph_mode():
return core.ops.c_split(tensor, 'use_calc_stream', True, 'ring_id',
ring_id, 'rank', rank, 'nranks', nranks,
......@@ -889,6 +898,24 @@ def _mp_allreduce(tensor,
raise NotImplementedError("No support _mp_allreduce in dygraph mode.")
def _c_lookup_table(table, index, start_index=0, name=None):
"""
Lookup table according to index.
Args:
table (Tensor): The input Tensor. Its data type
should be float16, float32, float64.
index (Tensor): The index to lookup table.
start_index (int): The initial index for table range.
name (string): The name of the api
Returns:
Tensor.
"""
if in_dygraph_mode():
return core.ops.c_embedding(table, index, "start_index", start_index)
class _Linear(layers.Layer):
"""
Linear
......@@ -995,7 +1022,7 @@ def _parallel_linear(x,
if axis == 0:
if split_tensor:
x = _c_split(x, inner_rank, nranks, group=group)
x = _c_split(x, group=group)
else:
x = _c_identity(x, group=group)
......
......@@ -43,14 +43,13 @@ class VocabParallelEmbedding(Layer):
self.origin_num_embeddings = num_embeddings
self.is_mp = (self.world_size > 1)
per_part_size = (
num_embeddings + self.world_size - 1) // self.world_size
last_part_size = num_embeddings - per_part_size * (self.world_size - 1)
if self.rank == self.world_size - 1:
per_part_size = last_part_size
per_part_size += 1 # make the last row as the padding index
self.per_part_size = per_part_size
assert num_embeddings % self.world_size == 0, (
"The length of the vocabulary must be divisible by the parallelism degree of MP"
)
per_part_size = num_embeddings // self.world_size
self.vocab_start_index = self.rank * per_part_size
self._dtype = self._helper.get_default_dtype()
self._size = [per_part_size, embedding_dim]
self._weight_attr = weight_attr
......@@ -63,49 +62,35 @@ class VocabParallelEmbedding(Layer):
shape=self._size,
dtype=self._dtype,
is_bias=False)
self.weight[per_part_size - 1] = 0.0
self.weight.is_distributed = True
else:
self.weight = self.create_parameter(
attr=self._weight_attr,
shape=[num_embeddings, embedding_dim],
shape=self._size,
dtype=self._dtype,
is_bias=False)
self.weight.is_distributed = True
def forward(self, x):
if not self.is_mp:
return F.embedding(
if self.is_mp:
output_parallel = paddle.distributed.collective._c_lookup_table(
self.weight,
x,
start_index=self.vocab_start_index,
name=self._name)
output = paddle.distributed.collective._mp_allreduce(
output_parallel,
group=self.model_parallel_group,
use_calc_stream=True,
use_model_parallel=True)
else:
output = F.embedding(
x,
weight=self.weight,
padding_idx=None,
sparse=False,
name=self._name)
origin_input_shape = x.shape
if len(origin_input_shape) == 2:
x = paddle.unsqueeze(x, axis=-1)
else:
assert origin_input_shape[-1] == 1, (
"The last dimension size of x must be 1.")
x_shard = paddle.shard_index(x, self.origin_num_embeddings,
self.world_size, self.rank,
self.per_part_size - 1)
if len(origin_input_shape) == 2:
x_shard = paddle.squeeze(x_shard, axis=-1)
emb_out = F.embedding(
x_shard,
weight=self.weight,
padding_idx=self.per_part_size - 1,
sparse=False,
name=self._name)
emb_out = paddle.distributed.collective._mp_allreduce(
emb_out,
group=self.model_parallel_group,
use_calc_stream=True,
use_model_parallel=True)
return emb_out
return output
class ColumnParallelLinear(Layer):
......@@ -175,9 +160,7 @@ class ColumnParallelLinear(Layer):
if self.gather_output and self.is_mp:
output = paddle.distributed.collective._c_concat(
output_parallel,
nranks=self.world_size,
group=self.model_parallel_group)
output_parallel, group=self.model_parallel_group)
else:
output = output_parallel
return output
......@@ -245,10 +228,7 @@ class RowParallelLinear(Layer):
else:
# split last dim
input_parallel = paddle.distributed.collective._c_split(
x,
rank=self.rank,
nranks=self.world_size,
group=self.model_parallel_group)
x, group=self.model_parallel_group)
output_parallel = F.linear(input_parallel, self.weight, name=self._name)
......
......@@ -80,6 +80,7 @@ if(((NOT WITH_ROCM) AND (NOT WITH_GPU)) OR WIN32)
LIST(REMOVE_ITEM TEST_OPS test_c_split)
LIST(REMOVE_ITEM TEST_OPS test_allgather)
LIST(REMOVE_ITEM TEST_OPS test_c_identity)
LIST(REMOVE_ITEM TEST_OPS test_c_embedding_op)
LIST(REMOVE_ITEM TEST_OPS test_allreduce)
LIST(REMOVE_ITEM TEST_OPS test_broadcast)
LIST(REMOVE_ITEM TEST_OPS test_collective_reduce)
......
......@@ -212,7 +212,7 @@ class TestDistTraning(unittest.TestCase):
optimizer_b.step()
np.testing.assert_allclose(
loss_a.numpy(), loss_b.numpy(), rtol=1e-5)
loss_a.numpy(), loss_b.numpy(), rtol=5e-6)
def test_parallel_embedding(self):
batch_size = 17
......@@ -265,8 +265,9 @@ class TestDistTraning(unittest.TestCase):
optimizer_a.step()
optimizer_b.step()
np.testing.assert_allclose(
loss_a.numpy(), loss_b.numpy(), rtol=1e-6)
print(loss_a.numpy(), loss_b.numpy())
np.testing.assert_allclose(loss_a.numpy(), loss_b.numpy())
if __name__ == '__main__':
......
......@@ -32,14 +32,36 @@ def set_random_seed(seed, dp_id, rank_id):
paddle.seed(seed + rank_id)
vocab_size = 5
vocab_size = 20
hidden_size = 10
inner_size = 8
output_size = 2
output_size = 10
seq_length = 2
batch_size = 4
def parallel_matmul(lm_output, logit_weights, parallel_output):
hcg = fleet.get_hybrid_communicate_group()
model_parallel_group = hcg.get_model_parallel_group()
world_size = hcg.get_model_parallel_world_size()
rank = hcg.get_model_parallel_rank()
if world_size > 1:
input_parallel = paddle.distributed.collective._c_identity(
lm_output, group=model_parallel_group)
logits = paddle.matmul(input_parallel, logit_weights, transpose_y=True)
if parallel_output:
return logits
return paddle.distributed.collective._c_concat(
logits, group=model_parallel_group)
else:
logits = paddle.matmul(lm_output, logit_weights, transpose_y=True)
return logits
class SimpleMPNet(fluid.dygraph.Layer):
def __init__(self, vocab_size, hidden_size, inner_size, output_size, np_fc1,
np_fc2, mp_id):
......@@ -86,6 +108,7 @@ class SimpleMPNet(fluid.dygraph.Layer):
x = self.linear1(x)
x = self.linear2(x)
x = self.linear3(x)
x = parallel_matmul(x, self.embedding.weight, False)
return x
......@@ -128,6 +151,7 @@ class SimpleDPNet(fluid.dygraph.Layer):
x = self.linear1(x)
x = self.linear2(x)
x = self.linear3(x)
x = paddle.matmul(x, self.embedding.weight, transpose_y=True)
return x
......@@ -192,7 +216,7 @@ class TestDistMPTraning(unittest.TestCase):
loss_b = self.train_batch(batch, model_b, optimizer_b, False)
np.testing.assert_allclose(
loss_a.numpy(), loss_b.numpy(), rtol=1e-5)
loss_a.numpy(), loss_b.numpy(), rtol=1e-6)
if __name__ == "__main__":
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
from op_test import OpTest
import paddle
import paddle.fluid as fluid
from paddle.framework import core
def get_c_embedding(start, end, table, ids):
index = ids.flatten()
input_mask = (index < start) | (index >= end)
masked_input = index - start
masked_input[input_mask] = 0
output = table[masked_input]
output[input_mask] = 0.0
return output
class TestCEmbeddingOp(OpTest):
def setUp(self):
self.op_type = "c_embedding"
table = np.random.random((17, 31)).astype("float64")
ids = np.random.randint(
low=0, high=17 * 2, size=(2, 4, 5)).astype("int32")
self.start_index = 10
self.end_index = self.start_index + 17
self.inputs = {'W': table, 'Ids': ids}
np_out = get_c_embedding(self.start_index, self.end_index, table, ids)
self.outputs = {'Out': np_out.reshape((2, 4, 5, 31))}
self.attrs = {'start_index': self.start_index}
def test_check_output_gpu(self):
if core.is_compiled_with_cuda():
self.check_output_with_place(core.CUDAPlace(0))
def test_check_grad_gpu(self):
if core.is_compiled_with_cuda():
self.check_grad_with_place(core.CUDAPlace(0), ['W'], 'Out')
if __name__ == "__main__":
unittest.main()
......@@ -711,4 +711,5 @@ STATIC_MODE_TESTING_LIST = [
'test_model_cast_to_bf16',
'test_sgd_op_bf16',
'test_marker_op',
'test_c_embedding_op',
]
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册