未验证 提交 8fa8a37f 编写于 作者: L lilong12 提交者: GitHub

add the c_identity op (#32485)

* add c_identity op, test=develop
上级 de947430
......@@ -321,6 +321,12 @@ class CAllReduceOpMaker : public framework::OpProtoAndCheckerMaker {
"use_calc_stream",
"(bool default false) eject CUDA operations to calculation stream.")
.SetDefault(false);
AddAttr<bool>(
"use_model_parallel",
"(bool default false) use this op with model parallel mode. In model "
"parallel mode, the backward is c_identity which returns itself for "
"c_allreduce_sum.")
.SetDefault(false);
AddComment(string::Sprintf(R"DOC(
CAllReduce %s Operator
......
......@@ -37,7 +37,12 @@ class CAllReduceSumOpGradMaker : public framework::SingleGradOpMaker<T> {
protected:
void Apply(GradOpPtr<T> retv) const override {
retv->SetType("c_allreduce_sum");
bool use_mp = BOOST_GET_CONST(bool, this->GetAttr("use_model_parallel"));
if (use_mp) {
retv->SetType("c_identity");
} else {
retv->SetType("c_allreduce_sum");
}
retv->SetInput("X", this->OutputGrad("Out"));
retv->SetOutput("Out", this->InputGrad("X"));
retv->SetAttrMap(this->Attrs());
......
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/collective/c_identity_op.h"
namespace paddle {
namespace operators {
class CIdentityOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "c_identity");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "c_identity");
int ring_id = ctx->Attrs().Get<int>("ring_id");
PADDLE_ENFORCE_GE(
ring_id, 0,
platform::errors::InvalidArgument(
"The ring_id (%d) for c_identity must be non-negative.", ring_id));
framework::DDim dim = ctx->GetInputDim("X");
ctx->SetOutputDim("Out", dim);
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
}
};
class CIdentityOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() {
AddInput("X", "(Tensor) identity tensor.");
AddOutput("Out", "(Tensor) identity tensor.");
AddAttr<int>("ring_id", "(int default 0) nccl communication ring id.")
.SetDefault(0);
AddAttr<bool>(
"use_calc_stream",
"(bool default true) eject CUDA operations to calculation stream.")
.SetDefault(true);
AddAttr<bool>("use_model_parallel",
"(bool default true) use this op with model parallel.")
.SetDefault(true);
AddComment(R"DOC(
Identity Operator which returns a copy of itself.
)DOC");
}
};
template <typename T>
class CIdentityOpGradMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
void Apply(GradOpPtr<T> retv) const override {
retv->SetType("c_allreduce_sum");
retv->SetInput("X", this->OutputGrad("Out"));
retv->SetOutput("Out", this->InputGrad("X"));
retv->SetAttrMap(this->Attrs());
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OPERATOR(c_identity, ops::CIdentityOp,
ops::CIdentityOpGradMaker<paddle::framework::OpDesc>,
ops::CIdentityOpGradMaker<paddle::imperative::OpBase>,
ops::CIdentityOpMaker);
REGISTER_OP_CPU_KERNEL(c_identity, ops::CIdentityOpCPUKernel<float>,
ops::CIdentityOpCPUKernel<double>,
ops::CIdentityOpCPUKernel<int>,
ops::CIdentityOpCPUKernel<int64_t>,
ops::CIdentityOpCPUKernel<plat::float16>);
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/collective/c_identity_op.h"
namespace paddle {
namespace operators {
template <typename T>
class CIdentityOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto x = ctx.Input<framework::LoDTensor>("X");
auto out = ctx.Output<framework::LoDTensor>("Out");
int rid = ctx.Attr<int>("ring_id");
PADDLE_ENFORCE_GE(
rid, 0,
platform::errors::InvalidArgument(
"The ring_id (%d) for c_identity op must be non-negative.", rid));
out->mutable_data<T>(ctx.GetPlace());
TensorCopy(*x, out->place(), out);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(c_identity, ops::CIdentityOpCUDAKernel<float>,
ops::CIdentityOpCUDAKernel<double>,
ops::CIdentityOpCUDAKernel<int>,
ops::CIdentityOpCUDAKernel<int64_t>,
ops::CIdentityOpCUDAKernel<plat::float16>);
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <algorithm>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
template <typename T>
class CIdentityOpCPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_THROW(platform::errors::Unavailable(
"Do not support c_identity for cpu kernel now."));
}
};
} // namespace operators
} // namespace paddle
......@@ -74,6 +74,7 @@ if(((NOT WITH_ROCM) AND (NOT WITH_GPU)) OR WIN32)
LIST(REMOVE_ITEM TEST_OPS test_c_concat)
LIST(REMOVE_ITEM TEST_OPS test_c_split)
LIST(REMOVE_ITEM TEST_OPS test_allgather)
LIST(REMOVE_ITEM TEST_OPS test_c_identity)
LIST(REMOVE_ITEM TEST_OPS test_allreduce)
LIST(REMOVE_ITEM TEST_OPS test_broadcast)
LIST(REMOVE_ITEM TEST_OPS test_collective_reduce)
......@@ -878,6 +879,7 @@ if((WITH_ROCM OR WITH_GPU) AND NOT WIN32)
set_tests_properties(test_c_concat PROPERTIES TIMEOUT 120)
set_tests_properties(test_c_split PROPERTIES TIMEOUT 120)
set_tests_properties(test_allgather PROPERTIES TIMEOUT 120)
set_tests_properties(test_c_identity PROPERTIES TIMEOUT 120)
set_tests_properties(test_collective_scatter_api PROPERTIES TIMEOUT 120)
set_tests_properties(test_collective_barrier_api PROPERTIES TIMEOUT 120)
set_tests_properties(test_collective_scatter PROPERTIES TIMEOUT 120)
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import numpy as np
import argparse
import os
import sys
import signal
import time
from contextlib import closing
from six import string_types
import math
import paddle
import paddle.fluid as fluid
import paddle.fluid.profiler as profiler
import paddle.fluid.unique_name as nameGen
from paddle.fluid import core
import unittest
from multiprocessing import Process
import paddle.fluid.layers as layers
from functools import reduce
from test_collective_base import TestCollectiveRunnerBase, runtime_main
paddle.enable_static()
class TestCollectiveIdentity(TestCollectiveRunnerBase):
def __init__(self):
self.global_ring_id = 0
def get_model(self, main_prog, startup_program):
ring_id = 0
nranks = 2
with fluid.program_guard(main_prog, startup_program):
tindata = layers.data(
name="tindata", shape=[10, 1000], dtype='float32')
toutdata = main_prog.current_block().create_var(
name="outofgather",
dtype='float32',
type=core.VarDesc.VarType.LOD_TENSOR,
persistable=False,
stop_gradient=False)
main_prog.global_block().append_op(
type="c_identity",
inputs={'X': tindata},
outputs={'Out': toutdata},
attrs={'ring_id': ring_id,
'nranks': nranks})
return toutdata
if __name__ == "__main__":
runtime_main(TestCollectiveIdentity, "identity", 0)
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import paddle
from test_collective_base import TestDistBase
paddle.enable_static()
class TestIdentityOp(TestDistBase):
def _setup_config(self):
pass
def test_identity(self, col_type="identity"):
self.check_with_place("collective_identity_op.py", col_type)
if __name__ == '__main__':
unittest.main()
......@@ -274,6 +274,11 @@ class TestDistBase(unittest.TestCase):
self.assertTrue(
np.allclose(
tr1_out, need_result, rtol=1e-05, atol=1e-05))
elif col_type == "identity":
need_result1 = input1
need_result2 = input2
self.assertTrue(np.allclose(tr0_out, need_result1, rtol=0, atol=0))
self.assertTrue(np.allclose(tr1_out, need_result2, rtol=0, atol=0))
elif col_type == "reduce_slicegather":
slicesize = input1.shape[0] // 2
tmp10 = input1[0:slicesize]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册