提交 29035265 编写于 作者: F fengjiayi

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into feature/python_api_design

# Design Doc: ProgramDesc
The basic structure of a PaddlePaddle program is some nested blocks, as a C++ or Java program.
As described in [graph.md](./graph.md), the first five lines of the following PaddlePaddle program
```python
x = layer.data("images")
l = layer.data("label")
y = layer.fc(x)
cost = layer.mse(y, l)
optimize(cost)
train(cost, reader=mnist.train())
```
generates, or compiles, a PaddelPaddle program, which is represented by the following protobuf message:
```protobuf
message ProgramDesc {
repeated BlockDesc blocks = 1;
}
message BlockDesc {
required int32 parent = 1;
repeated VarDesc vars = 2;
repeated OpDesc ops = 3;
}
message OpDesc {
AttrDesc attrs = 1;
...
}
message AttrDesc {
required AttrType type = 1;
// index into ProgramDesc::blocks when type==BLOCK
optional int32 block = 2;
...
}
```
When each of the first five lines runs, related Python function, e.g., `layer.fc`, calls C++ InferShape functions. This InferShape function needs to access the properties of VarDesc's accessed by the current OpDesc. These VarDesc's might not be defined in the current block, but in some ancestor blocks. This requires that we can trace the parent of a block.
A nested block is often an attribute of an operator, most likely, an IfElseOp or a WhileOp. In above solution, all blocks are in `ProgramDesc::blocks`, this implicitly assigns a zero-based ID to each block -- the index of the block in `ProgramDesc::blocks`. So that `AttrDesc::block` could be an integer block ID.
With this design, the InferShape function should take the following parameters:
```c++
void InferShape(int current_block,
int current_operator,
ProgramDesc* program // might change VarDesc values.
) {
...
}
```
where
- `current_block` indices into `ProgramDesc::blocks`,
- `current_operator` indices into `BlockDesc::ops`.
......@@ -19,12 +19,14 @@ cc_test(scope_test SRCS scope_test.cc DEPS scope)
proto_library(framework_proto SRCS framework.proto)
cc_library(attribute SRCS attribute.cc DEPS framework_proto)
cc_library(op_proto_maker SRCS op_proto_maker.cc DEPS framework_proto attribute)
cc_test(op_proto_maker_test SRCS op_proto_maker_test.cc DEPS op_proto_maker)
cc_library(op_info SRCS op_info.cc DEPS attribute framework_proto)
cc_library(operator SRCS operator.cc DEPS op_info device_context tensor scope)
cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry)
cc_library(grad_op_builder SRCS grad_op_builder.cc DEPS operator)
cc_library(op_registry SRCS op_registry.cc DEPS grad_op_builder)
cc_library(op_registry SRCS op_registry.cc DEPS grad_op_builder op_proto_maker)
cc_test(op_registry_test SRCS op_registry_test.cc DEPS op_registry)
cc_test(grad_op_builder_test SRCS grad_op_builder_test.cc DEPS grad_op_builder op_registry add_op)
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/framework/op_proto_maker.h"
namespace paddle {
namespace framework {
void OpProtoAndCheckerMaker::Validate() {
validated_ = true;
CheckNoDuplicatedInOutAttrs();
}
OpProtoAndCheckerMaker::VariableBuilder OpProtoAndCheckerMaker::AddInput(
const std::string& name, const std::string& comment) {
auto* input = proto_->add_inputs();
input->set_name(name);
input->set_comment(comment);
return OpProtoAndCheckerMaker::VariableBuilder{input};
}
OpProtoAndCheckerMaker::VariableBuilder OpProtoAndCheckerMaker::AddOutput(
const std::string& name, const std::string& comment) {
auto* output = proto_->add_outputs();
output->set_name(name);
output->set_comment(comment);
return OpProtoAndCheckerMaker::VariableBuilder{output};
}
void OpProtoAndCheckerMaker::CheckNoDuplicatedInOutAttrs() {
std::unordered_set<std::string> names;
auto checker = [&](const std::string& name) {
PADDLE_ENFORCE(!names.count(name), "[%s] is duplicated", name);
names.insert(name);
};
for (auto& attr : proto_->attrs()) {
checker(attr.name());
}
for (auto& input : proto_->inputs()) {
checker(input.name());
}
for (auto& output : proto_->outputs()) {
checker(output.name());
}
}
} // namespace framework
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/attribute.h"
#include "paddle/framework/framework.pb.h"
namespace paddle {
namespace framework {
// this class not only make proto but also init attribute checkers.
class OpProtoAndCheckerMaker {
public:
OpProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker)
: proto_(proto), op_checker_(op_checker) {}
virtual ~OpProtoAndCheckerMaker() {
PADDLE_ENFORCE(validated_, "should call Validate after build");
}
void Validate();
protected:
struct VariableBuilder {
OpProto::Var* var_;
VariableBuilder& AsDuplicable() {
var_->set_duplicable(true);
return *this;
}
VariableBuilder& AsIntermediate() {
var_->set_intermediate(true);
return *this;
}
VariableBuilder& NotInGradient() {
var_->set_not_in_gradient(true);
return *this;
}
};
VariableBuilder AddInput(const std::string& name, const std::string& comment);
VariableBuilder AddOutput(const std::string& name,
const std::string& comment);
template <typename T>
TypedAttrChecker<T>& AddAttr(const std::string& name,
const std::string& comment,
bool generated = false) {
auto* attr = proto_->add_attrs();
attr->set_name(name);
attr->set_comment(comment);
attr->set_generated(generated);
attr->set_type(AttrTypeID<T>());
return op_checker_->AddAttrChecker<T>(name);
}
void AddComment(const std::string& comment) { proto_->set_comment(comment); }
private:
void CheckNoDuplicatedInOutAttrs();
OpProto* proto_;
OpAttrChecker* op_checker_;
bool validated_{false};
};
class NOPMaker : public OpProtoAndCheckerMaker {
public:
NOPMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {}
};
} // namespace framework
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/framework/op_proto_maker.h"
#include "gtest/gtest.h"
class TestAttrProtoMaker : public paddle::framework::OpProtoAndCheckerMaker {
public:
TestAttrProtoMaker(paddle::framework::OpProto* proto,
paddle::framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddAttr<float>("scale", "scale of test op");
AddAttr<float>("scale", "scale of test op");
}
};
TEST(ProtoMaker, DuplicatedAttr) {
paddle::framework::OpProto op_proto;
paddle::framework::OpAttrChecker op_checker;
auto proto_maker = TestAttrProtoMaker(&op_proto, &op_checker);
ASSERT_THROW(proto_maker.Validate(), paddle::platform::EnforceNotMet);
}
class TestInOutProtoMaker : public paddle::framework::OpProtoAndCheckerMaker {
public:
TestInOutProtoMaker(paddle::framework::OpProto* proto,
paddle::framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("input", "input of test op");
AddInput("input", "input of test op");
}
};
TEST(ProtoMaker, DuplicatedInOut) {
paddle::framework::OpProto op_proto;
paddle::framework::OpAttrChecker op_checker;
auto proto_maker = TestInOutProtoMaker(&op_proto, &op_checker);
ASSERT_THROW(proto_maker.Validate(), paddle::platform::EnforceNotMet);
}
\ No newline at end of file
......@@ -24,6 +24,7 @@ limitations under the License. */
#include "paddle/framework/framework.pb.h"
#include "paddle/framework/grad_op_builder.h"
#include "paddle/framework/op_info.h"
#include "paddle/framework/op_proto_maker.h"
#include "paddle/framework/operator.h"
#include "paddle/framework/scope.h"
......
......@@ -228,43 +228,5 @@ std::vector<Tensor*> ExecutionContext::MultiOutput<Tensor>(
return res;
}
void OpProtoAndCheckerMaker::Validate() {
validated_ = true;
CheckNoDuplicatedInOutAttrs();
}
OpProtoAndCheckerMaker::VariableBuilder OpProtoAndCheckerMaker::AddInput(
const std::string& name, const std::string& comment) {
auto* input = proto_->add_inputs();
input->set_name(name);
input->set_comment(comment);
return OpProtoAndCheckerMaker::VariableBuilder{input};
}
OpProtoAndCheckerMaker::VariableBuilder OpProtoAndCheckerMaker::AddOutput(
const std::string& name, const std::string& comment) {
auto* output = proto_->add_outputs();
output->set_name(name);
output->set_comment(comment);
return OpProtoAndCheckerMaker::VariableBuilder{output};
}
void OpProtoAndCheckerMaker::CheckNoDuplicatedInOutAttrs() {
std::unordered_set<std::string> names;
auto checker = [&](const std::string& name) {
PADDLE_ENFORCE(!names.count(name), "[%s] is duplicated", name);
names.insert(name);
};
for (auto& attr : proto_->attrs()) {
checker(attr.name());
}
for (auto& input : proto_->inputs()) {
checker(input.name());
}
for (auto& output : proto_->outputs()) {
checker(output.name());
}
}
} // namespace framework
} // namespace paddle
......@@ -167,71 +167,6 @@ class NOP : public OperatorBase {
}
};
// this class not only make proto but also init attribute checkers.
class OpProtoAndCheckerMaker {
public:
OpProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker)
: proto_(proto), op_checker_(op_checker) {}
~OpProtoAndCheckerMaker() {
PADDLE_ENFORCE(validated_, "should call Validate after build");
}
void Validate();
protected:
struct VariableBuilder {
OpProto::Var* var_;
VariableBuilder& AsDuplicable() {
var_->set_duplicable(true);
return *this;
}
VariableBuilder& AsIntermediate() {
var_->set_intermediate(true);
return *this;
}
VariableBuilder& NotInGradient() {
var_->set_not_in_gradient(true);
return *this;
}
};
VariableBuilder AddInput(const std::string& name, const std::string& comment);
VariableBuilder AddOutput(const std::string& name,
const std::string& comment);
template <typename T>
TypedAttrChecker<T>& AddAttr(const std::string& name,
const std::string& comment,
bool generated = false) {
auto* attr = proto_->add_attrs();
attr->set_name(name);
attr->set_comment(comment);
attr->set_generated(generated);
attr->set_type(AttrTypeID<T>());
return op_checker_->AddAttrChecker<T>(name);
}
void AddComment(const std::string& comment) { proto_->set_comment(comment); }
private:
void CheckNoDuplicatedInOutAttrs();
OpProto* proto_;
OpAttrChecker* op_checker_;
bool validated_{false};
};
class NOPMaker : public OpProtoAndCheckerMaker {
public:
NOPMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {}
};
class InferShapeContext {
public:
InferShapeContext(const OperatorBase& op, const Scope& scope)
......
......@@ -264,37 +264,3 @@ TEST(Operator, Clone) {
auto b = a.Clone();
ASSERT_EQ(a.Type(), b->Type());
}
class TestAttrProtoMaker : public paddle::framework::OpProtoAndCheckerMaker {
public:
TestAttrProtoMaker(paddle::framework::OpProto* proto,
paddle::framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddAttr<float>("scale", "scale of test op");
AddAttr<float>("scale", "scale of test op");
}
};
TEST(ProtoMaker, DuplicatedAttr) {
paddle::framework::OpProto op_proto;
paddle::framework::OpAttrChecker op_checker;
auto proto_maker = TestAttrProtoMaker(&op_proto, &op_checker);
ASSERT_THROW(proto_maker.Validate(), paddle::platform::EnforceNotMet);
}
class TestInOutProtoMaker : public paddle::framework::OpProtoAndCheckerMaker {
public:
TestInOutProtoMaker(paddle::framework::OpProto* proto,
paddle::framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("input", "input of test op");
AddInput("input", "input of test op");
}
};
TEST(ProtoMaker, DuplicatedInOut) {
paddle::framework::OpProto op_proto;
paddle::framework::OpAttrChecker op_checker;
auto proto_maker = TestInOutProtoMaker(&op_proto, &op_checker);
ASSERT_THROW(proto_maker.Validate(), paddle::platform::EnforceNotMet);
}
\ No newline at end of file
......@@ -96,3 +96,4 @@ set(GLOB_OP_LIB ${OP_LIBRARY} CACHE INTERNAL "Global OP library")
cc_test(gather_test SRCS gather_test.cc DEPS tensor)
cc_test(net_op_test SRCS net_op_test.cc DEPS net_op)
cc_test(scatter_test SRCS scatter_test.cc DEPS tensor)
cc_test(strided_memcpy_test SRCS strided_memcpy_test.cc DEPS tensor paddle_memory)
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/ddim.h"
#include "paddle/memory/memcpy.h"
#include "paddle/platform/device_context.h"
namespace paddle {
namespace operators {
namespace detail {
template <typename T, int Rank>
struct StridedMemcpyFunctor;
template <typename T>
struct StridedMemcpyFunctor<T, 1> {
void operator()(const platform::DeviceContext& dev_ctx, const T* src,
framework::Dim<1> src_stride, framework::Dim<1> dst_dim,
framework::Dim<1> dst_stride, T* dst) const {
auto place = dev_ctx.GetPlace();
if (platform::is_cpu_place(place)) {
auto& cpu_place = boost::get<platform::CPUPlace>(place);
memory::Copy(cpu_place, dst, cpu_place, src, sizeof(T) * dst_dim.head);
} else {
#ifndef PADDLE_ONLY_CPU
auto& gpu_place = boost::get<platform::GPUPlace>(place);
auto& cuda_ctx =
reinterpret_cast<const platform::CUDADeviceContext&>(dev_ctx);
memory::Copy(gpu_place, dst, gpu_place, src, sizeof(T) * dst_dim.head,
cuda_ctx.stream());
#else
PADDLE_THROW("Paddle is not compiled with GPU");
#endif
}
}
};
template <typename T, int Rank>
struct StridedMemcpyFunctor {
void operator()(const platform::DeviceContext& dev_ctx, const T* src,
framework::Dim<Rank> src_stride, framework::Dim<Rank> dst_dim,
framework::Dim<Rank> dst_stride, T* dst) const {
for (int64_t i = 0; i < dst_dim.head; ++i) {
StridedMemcpyFunctor<T, Rank - 1> func;
func(dev_ctx, src, src_stride.tail, dst_dim.tail, dst_stride.tail, dst);
src += src_stride.head;
dst += dst_stride.head;
}
}
};
template <typename T>
struct StridedCopyDimVisitor : public boost::static_visitor<void> {
StridedCopyDimVisitor(const platform::DeviceContext& dev_ctx, const T* src,
const framework::DDim& src_stride,
const framework::DDim& dst_stride, T* dst)
: dev_ctx_(dev_ctx),
src_(src),
src_stride_(src_stride),
dst_stride_(dst_stride),
dst_(dst) {}
template <typename Dim>
void operator()(Dim dst_dim) const {
Dim src_stride = boost::get<Dim>(src_stride_);
Dim dst_stride = boost::get<Dim>(dst_stride_);
constexpr int dim = Dim::dimensions;
StridedMemcpyFunctor<T, dim> functor;
functor(dev_ctx_, src_, src_stride, dst_dim, dst_stride, dst_);
}
const platform::DeviceContext& dev_ctx_;
const T* src_;
const framework::DDim& src_stride_;
const framework::DDim& dst_stride_;
T* dst_;
};
} // namespace detail
} // namespace operators
} // namespace paddle
......@@ -96,7 +96,7 @@ class PReluGradKernel : public framework::OpKernel {
trans(context.device_context(), out_ptr, out_ptr + numel, dout_ptr, dx_ptr,
PReluGradFunctor<T>(alpha_ptr));
// TODO (Zhuoyuan): add dalpha upgrade when GPU kernels ready
// TODO(Zhuoyuan): add dalpha upgrade when GPU kernels ready
}
};
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/operators/detail/strided_memcpy.h"
namespace paddle {
namespace operators {
// Strided memory copy from src to dst.
//
// The src and dst should be both on dev_ctx.GetPlace(), otherwise, there will
// be a segment fault.
//
// The stride of an array (also referred to as increment, pitch or step size) is
// the number of locations in memory between beginnings of successive array
// elements
//
// For example, for tensor like [1, 3, 300, 300]. If there is no padding, the
// stride is [270000, 90000, 300, 1].
//
// NOTE: When use GPU, the memcpy is async. To sync memcpy, please invoke
// `dev_ctx.Wait()`.
template <typename T>
inline void StridedMemcpy(const platform::DeviceContext& dev_ctx, const T* src,
const framework::DDim& src_stride,
const framework::DDim& dst_dim,
const framework::DDim& dst_stride, T* dst) {
using namespace detail;
StridedCopyDimVisitor<T> func(dev_ctx, src, src_stride, dst_stride, dst);
boost::apply_visitor(func, dst_dim);
}
} // namespace operators
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/strided_memcpy.h"
#include "gtest/gtest.h"
#include "paddle/memory/memory.h"
namespace paddle {
namespace operators {
TEST(StridedMemcpy, CPUCrop) {
// clang-format off
int src[] = {
0, 1, 2, 0, 0,
0, 3, 4, 0, 0,
0, 0, 0, 0, 0,
};
// clang-format on
framework::DDim src_stride({5, 1});
int dst[4];
framework::DDim dst_dim({2, 2});
framework::DDim dst_stride({2, 1});
platform::CPUDeviceContext ctx;
StridedMemcpy<int>(ctx, src + 1, src_stride, dst_dim, dst_stride, dst);
ASSERT_EQ(1, dst[0]);
ASSERT_EQ(2, dst[1]);
ASSERT_EQ(3, dst[2]);
ASSERT_EQ(4, dst[3]);
}
TEST(StridedMemcpy, CPUConcat) {
// clang-format off
int src[] = {
1, 2,
3, 4
};
// clang-format on
int dst[8];
framework::DDim src_stride({2, 1});
framework::DDim dst_dim({2, 2});
framework::DDim dst_stride({4, 1});
platform::CPUDeviceContext ctx;
StridedMemcpy<int>(ctx, src, src_stride, dst_dim, dst_stride, dst);
StridedMemcpy<int>(ctx, src, src_stride, dst_dim, dst_stride, dst + 2);
// clang-format off
int expect_dst[] = {
1, 2, 1, 2,
3, 4, 3, 4
};
// clang-format on
for (size_t i = 0; i < sizeof(expect_dst) / sizeof(int); ++i) {
ASSERT_EQ(expect_dst[i], dst[i]);
}
}
#ifndef PADDLE_ONLY_CPU
TEST(StridedMemcpy, GPUCrop) {
// clang-format off
int src[] = {
0, 1, 2, 0, 0,
0, 3, 4, 0, 0,
0, 0, 0, 0, 0,
};
// clang-format on
platform::GPUPlace gpu0(0);
platform::CPUPlace cpu;
int* gpu_src = reinterpret_cast<int*>(memory::Alloc(gpu0, sizeof(src)));
memory::Copy(gpu0, gpu_src, cpu, src, sizeof(src));
framework::DDim src_stride({5, 1});
int dst[4];
int* gpu_dst = reinterpret_cast<int*>(memory::Alloc(gpu0, sizeof(dst)));
framework::DDim dst_dim({2, 2});
framework::DDim dst_stride({2, 1});
platform::CUDADeviceContext ctx(gpu0);
StridedMemcpy<int>(ctx, gpu_src + 1, src_stride, dst_dim, dst_stride,
gpu_dst);
memory::Copy(cpu, dst, gpu0, gpu_dst, sizeof(dst), ctx.stream());
ctx.Wait();
ASSERT_EQ(1, dst[0]);
ASSERT_EQ(2, dst[1]);
ASSERT_EQ(3, dst[2]);
ASSERT_EQ(4, dst[3]);
memory::Free(gpu0, gpu_dst);
memory::Free(gpu0, gpu_src);
}
TEST(StridedMemcpy, GPUConcat) {
// clang-format off
int src[] = {
1, 2,
3, 4
};
// clang-format on
platform::GPUPlace gpu0(0);
platform::CPUPlace cpu;
int* gpu_src = reinterpret_cast<int*>(memory::Alloc(gpu0, sizeof(src)));
memory::Copy(gpu0, gpu_src, cpu, src, sizeof(src));
int dst[8];
int* gpu_dst = reinterpret_cast<int*>(memory::Alloc(gpu0, sizeof(dst)));
framework::DDim src_stride({2, 1});
framework::DDim dst_dim({2, 2});
framework::DDim dst_stride({4, 1});
platform::CUDADeviceContext ctx(gpu0);
StridedMemcpy<int>(ctx, gpu_src, src_stride, dst_dim, dst_stride, gpu_dst);
StridedMemcpy<int>(ctx, gpu_src, src_stride, dst_dim, dst_stride,
gpu_dst + 2);
memory::Copy(cpu, dst, gpu0, gpu_dst, sizeof(dst), ctx.stream());
ctx.Wait();
// clang-format off
int expect_dst[] = {
1, 2, 1, 2,
3, 4, 3, 4
};
// clang-format on
for (size_t i = 0; i < sizeof(expect_dst) / sizeof(int); ++i) {
ASSERT_EQ(expect_dst[i], dst[i]);
}
memory::Free(gpu0, gpu_dst);
memory::Free(gpu0, gpu_src);
}
#endif
} // namespace operators
} // namespace paddle
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册