未验证 提交 0f353ab4 编写于 作者: Q Qiao Longfei 提交者: GitHub

cpu gpu transform function (#7191)

* add rename guard

* add device_data_transform

* add device_data_transform_test

* modify GetExpectedKernelType

* update operator.run

* support test test_label_semantic_roles

* optimize code

* optimize code

* rename GetActualKernelType to GetExpectedKernelType

* fix chunk_eval_op and device_data_transform_test

* add is_same_place to place

* optimize code, refine rename_guard

* refine rename guard, add GetKernelTypeForVar

* optimize code

* add some log

* rename guard

* use sub scope to create var

* fix compile

* add IsInitialized for Tensor

* add VarIsTensor

* fix op_registry_test

* test

* tmp disable priority

* restore switch_kernel.md

* code clean
上级 8814bec0
......@@ -32,7 +32,9 @@ cc_test(threadpool_test SRCS threadpool_test.cc DEPS threadpool)
cc_library(scope SRCS scope.cc DEPS glog threadpool)
cc_test(scope_test SRCS scope_test.cc DEPS scope)
cc_library(data_transform SRCS data_transform.cc DEPS math_function tensor framework_proto)
cc_library(device_data_transform SRCS device_data_transform.cc DEPS tensor)
cc_library(data_transform SRCS data_transform.cc DEPS math_function tensor framework_proto selected_rows device_data_transform)
cc_test(data_transform_test SRCS data_transform_test.cc DEPS data_transform device_context)
cc_library(attribute SRCS attribute.cc DEPS framework_proto)
......@@ -77,3 +79,6 @@ cc_library(init SRCS init.cc DEPS gflags device_context place stringpiece operat
cc_test(init_test SRCS init_test.cc DEPS init)
cc_test(op_kernel_type_test SRCS op_kernel_type_test.cc DEPS place device_context framework_proto)
nv_test(device_data_transform_test SRCS device_data_transform_test.cu
DEPS operator op_registry init math_function)
......@@ -14,7 +14,9 @@ limitations under the License. */
#include <functional>
#include "paddle/framework/data_transform.h"
#include "paddle/framework/device_data_transform.h"
#include "paddle/framework/lod_tensor.h"
#include "paddle/framework/selected_rows.h"
#include "paddle/platform/device_context.h"
namespace paddle {
......@@ -25,6 +27,37 @@ DataTransformFnMap& DataTransformFnMap::Instance() {
return data_transform_map;
}
// Produces a copy of `input_tensor` that matches `expected_kernel_type`.
//
// Only a difference in place (device) is handled here: when the place the
// variable currently lives on differs from the expected one, the tensor is
// copied with DeviceTransform. When the places are the same no transform is
// performed and the null check below fails, so callers are expected to invoke
// this only when a place transform is actually needed.
//
// The returned pointer is whatever DeviceTransform returned.
Tensor* DataTransform(const OpKernelType& expected_kernel_type,
                      const OpKernelType& kernel_type_for_var,
                      const Tensor& input_tensor) {
  const bool place_differs = !platform::is_same_place(
      kernel_type_for_var.place_, expected_kernel_type.place_);

  Tensor* transformed =
      place_differs
          ? DeviceTransform(input_tensor, expected_kernel_type.place_)
          : nullptr;

  PADDLE_ENFORCE_NOT_NULL(transformed, "out should not be null");
  return transformed;
}
void CopyVariableWithTensor(const Variable& in_var, const Tensor& tensor,
Variable& out_var) {
if (in_var.IsType<LoDTensor>()) {
auto& in_lod_tensor = in_var.Get<LoDTensor>();
auto* tran_lod_tensor = out_var.GetMutable<LoDTensor>();
tran_lod_tensor->set_lod(in_lod_tensor.lod());
tran_lod_tensor->set_layout(in_lod_tensor.layout());
tran_lod_tensor->ShareDataWith(tensor);
} else if (in_var.IsType<SelectedRows>()) {
auto& in_selected_rows = in_var.Get<SelectedRows>();
auto* trans_selected_rows = out_var.GetMutable<SelectedRows>();
trans_selected_rows->set_height(in_selected_rows.height());
trans_selected_rows->set_rows(in_selected_rows.rows());
trans_selected_rows->mutable_value()->ShareDataWith(tensor);
} else {
PADDLE_THROW("unknown var type");
}
}
auto KernelFP32 = OpKernelType(proto::DataType::FP32, platform::CPUPlace(),
DataLayout::kNHWC, LibraryType::kPlain);
......
......@@ -19,6 +19,7 @@ limitations under the License. */
#include <vector>
#include "paddle/framework/op_kernel_type.h"
#include "paddle/framework/selected_rows.h"
#include "paddle/framework/tensor.h"
#include "paddle/framework/variable.h"
#include "paddle/operators/math/math_function.h"
......@@ -49,6 +50,13 @@ struct KernelTypePairHash {
}
};
// Transforms `input_tensor` so that it matches `expected_kernel_type`.
// `kernel_type_for_var` describes the kernel type the input currently
// corresponds to. In the implementation shown in this change only a place
// (device) mismatch is handled, and the result is enforced to be non-null.
// NOTE(review): the returned Tensor comes from DeviceTransform, which
// allocates it with `new`; ownership appears to pass to the caller — confirm.
Tensor* DataTransform(const OpKernelType& expected_kernel_type,
const OpKernelType& kernel_type_for_var,
const Tensor& input_tensor);
// Populates `out_var` with the same container type as `in_var` (LoDTensor or
// SelectedRows), copying its metadata from `in_var` and sharing the data of
// `tensor`. Throws for any other variable type.
void CopyVariableWithTensor(const Variable& in_var, const Tensor& tensor,
Variable& out_var);
template <typename InType, typename OutType>
struct CastDataTypeFunctor {
HOSTDEVICE inline OutType operator()(InType in) const {
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/framework/device_data_transform.h"

#include <memory>
namespace paddle {
namespace framework {
// Picks the device context used for copying between `src_place` and
// `dst_place`. Only CPU <-> GPU pairs are supported; in both directions the
// context of the GPU-side place is returned. Any other pair throws.
static const platform::DeviceContext* GetDeviceContext(
    const platform::Place& src_place, const platform::Place& dst_place) {
  platform::DeviceContextPool& ctx_pool =
      platform::DeviceContextPool::Instance();

  // Points at whichever of the two places is the GPU side of the copy.
  const platform::Place* gpu_side = nullptr;
  if (platform::is_gpu_place(src_place) && platform::is_cpu_place(dst_place)) {
    gpu_side = &src_place;
  } else if (platform::is_cpu_place(src_place) &&
             platform::is_gpu_place(dst_place)) {
    gpu_side = &dst_place;
  }

  if (gpu_side == nullptr) {
    PADDLE_THROW(
        "Currently, model parallelism is only supported between CPU and CUDA");
  }
  return ctx_pool.Get(*gpu_side);
}
// Copies `in` to `dst_place` and returns the copy as a newly allocated Tensor.
//
// The returned Tensor is heap-allocated; the caller receives the raw pointer
// and is responsible for releasing it.
//
// Throws (via GetDeviceContext) when the source/destination pair is not a
// CPU <-> GPU pair.
Tensor* DeviceTransform(const Tensor& in, const platform::Place& dst_place) {
  VLOG(3) << "DeviceTransform in, src_place " << in.place()
          << " dst_place: " << dst_place;

  // Resolve the device context before allocating anything: this call throws
  // for unsupported place pairs.
  auto* dev_ctx = GetDeviceContext(in.place(), dst_place);

  // Hold the result in a unique_ptr so it is freed if the copy below throws;
  // previously the Tensor leaked on any exception after `new`.
  std::unique_ptr<Tensor> out(new Tensor());

  // Wait on the context before starting the copy and again after issuing it.
  dev_ctx->Wait();
  CopyFrom(in, dst_place, *dev_ctx, out.get());
  dev_ctx->Wait();

  return out.release();
}
} // namespace framework
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/lod_tensor.h"
#include "paddle/framework/tensor.h"
#include "paddle/framework/tensor_util.h"
#include "paddle/platform/device_context.h"
namespace paddle {
namespace framework {
// Copies `in` to `dst_place` and returns the copy. The result is allocated
// with `new` by the implementation, so the caller is responsible for the
// returned pointer. Only CPU <-> GPU transfers are supported; other place
// pairs cause an exception.
Tensor* DeviceTransform(const Tensor& in, const platform::Place& dst_place);
} // namespace framework
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "gtest/gtest.h"
#include "paddle/framework/init.h"
#include "paddle/framework/lod_tensor.h"
#include "paddle/framework/op_info.h"
#include "paddle/framework/op_registry.h"
#include "paddle/operators/elementwise_op_function.h"
#include "paddle/operators/math/math_function.h"
#include "paddle/platform/device_context.h"
namespace paddle {
namespace framework {
// Binary functor returning a + b. Used by TestKernel below, which passes the
// same tensor as both operands.
template <typename T>
struct AddFunctor {
inline HOSTDEVICE T operator()(T a, T b) const { return a + b; }
};
// Proto/attribute description of the test operator: one input ("input"), one
// output ("output") and a boolean attribute "use_gpu" that defaults to false.
class OpKernelTestProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
public:
OpKernelTestProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("input", "input1 of test op");
AddOutput("output", "output of test op");
// When true, TestOpWithKernel requests the CUDA kernel regardless of where
// the input tensor lives.
AddAttr<bool>("use_gpu", "force to use gpu kernel").SetDefault(false);
AddComment("This is test op");
}
};
// Test operator whose expected kernel place is either forced to CUDAPlace(0)
// (attribute "use_gpu" == true) or taken from the place of its input tensor.
class TestOpWithKernel : public OperatorWithKernel {
 public:
  using OperatorWithKernel::OperatorWithKernel;

 protected:
  // No shape inference is needed for this test operator.
  void InferShape(framework::InferShapeContext* ctx) const override {}

  OpKernelType GetExpectedKernelType(
      const ExecutionContext& ctx) const override {
    const bool force_gpu = Attr<bool>("use_gpu");
    if (!force_gpu) {
      // Default: run wherever the input currently lives.
      VLOG(3) << "use default kernel";
      return OpKernelType(proto::DataType::FP32,
                          ctx.Input<Tensor>("input")->place());
    }
    VLOG(3) << "force use gpu kernel";
    return OpKernelType(proto::DataType::FP32, platform::CUDAPlace(0));
  }
};
// Kernel of the test operator: writes input + input (element-wise) to the
// output, allocating the output on the place of the execution context.
//
// Derives from OpKernel<T> (rather than a hard-coded OpKernel<float>) so the
// base class agrees with the element type the kernel is instantiated with,
// and marks Compute as `override` so a signature mismatch with the base
// class is caught at compile time.
template <typename DeviceContext, typename T>
class TestKernel : public OpKernel<T> {
 public:
  void Compute(const ExecutionContext& ctx) const override {
    std::cout << ctx.op().DebugString() << std::endl;

    const Tensor* input = ctx.Input<Tensor>("input");
    std::cout << "input place:" << input->place() << std::endl;

    auto* output = ctx.Output<framework::LoDTensor>("output");
    output->Resize(input->dims());
    output->mutable_data<T>(ctx.GetPlace());

    // The same tensor is used for both operands, so output = 2 * input.
    operators::TransformFunctor<AddFunctor<T>, T, DeviceContext> functor(
        input, input, output, ctx.template device_context<DeviceContext>(),
        AddFunctor<T>());
    functor.Run();
  }
};
} // namespace framework
} // namespace paddle
// Register the test operator (no gradient op) together with its proto maker.
REGISTER_OP_WITHOUT_GRADIENT(
test_op, paddle::framework::TestOpWithKernel,
paddle::framework::OpKernelTestProtoAndCheckerMaker);
// Register a float kernel of test_op for the CPU device context ...
REGISTER_OP_CPU_KERNEL(
test_op,
paddle::framework::TestKernel<paddle::platform::CPUDeviceContext, float>);
// ... and one for the CUDA device context.
REGISTER_OP_CUDA_KERNEL(
test_op,
paddle::framework::TestKernel<paddle::platform::CUDADeviceContext, float>);
// Fills an OpDesc::Var entry: sets its parameter name and appends every name
// in `arguments` to its argument list, in order.
static void BuildVar(const std::string& param_name,
                     std::initializer_list<const char*> arguments,
                     paddle::framework::proto::OpDesc::Var* var) {
  var->set_parameter(param_name);

  auto* arg_list = var->mutable_arguments();
  for (const char* name : arguments) {
    arg_list->Add()->assign(name);
  }
}
// Runs test_op twice: first on the CPU with a CPU input, then with
// use_gpu=true on the CPU-resident output of the first run. The second run
// therefore requests a CUDA kernel for an input that lives on the CPU.
// Each run doubles its input, so the expected values are 2*i and then 4*i.
TEST(Operator, CPUtoGPU) {
using namespace paddle::framework;
using namespace paddle::platform;
ASSERT_EQ(InitDevices({"CPU", "GPU:0"}), true);
paddle::framework::Scope scope;
paddle::platform::CPUPlace cpu_place;
// create an op to run on CPU
paddle::framework::proto::OpDesc cpu_op_desc;
cpu_op_desc.set_type("test_op");
BuildVar("input", {"IN1"}, cpu_op_desc.add_inputs());
BuildVar("output", {"OUT1"}, cpu_op_desc.add_outputs());
auto cpu_op = paddle::framework::OpRegistry::CreateOp(cpu_op_desc);
// prepare input: a 2x3 float tensor on the CPU holding 0, 1, ..., 5
auto* in_t = scope.Var("IN1")->GetMutable<LoDTensor>();
auto* src_ptr = in_t->mutable_data<float>({2, 3}, CPUPlace());
for (int i = 0; i < 2 * 3; ++i) {
src_ptr[i] = static_cast<float>(i);
}
// create the output variable, run the CPU op and check OUT1 == 2 * IN1
auto* output = scope.Var("OUT1");
cpu_op->Run(scope, cpu_place);
auto* output_ptr = output->Get<LoDTensor>().data<float>();
for (int i = 0; i < 2 * 3; ++i) {
ASSERT_EQ(output_ptr[i], static_cast<float>(i) * 2);
}
// create an op to run on GPU; its input is the CPU tensor OUT1
paddle::framework::proto::OpDesc gpu_op_desc;
gpu_op_desc.set_type("test_op");
BuildVar("input", {"OUT1"}, gpu_op_desc.add_inputs());
BuildVar("output", {"OUT2"}, gpu_op_desc.add_outputs());
// use_gpu=true makes TestOpWithKernel request the CUDAPlace(0) kernel
auto attr = gpu_op_desc.mutable_attrs()->Add();
attr->set_name("use_gpu");
attr->set_type(paddle::framework::proto::AttrType::BOOLEAN);
attr->set_b(true);
auto gpu_op = paddle::framework::OpRegistry::CreateOp(gpu_op_desc);
paddle::platform::CUDAPlace cuda_place(0);
// create the output variable and run the GPU op
auto* output2 = scope.Var("OUT2");
gpu_op->Run(scope, cuda_place);
// OUT2 is not read directly: it is first copied to a CPU tensor through the
// CUDA device context, and the context is waited on before the data is read.
DeviceContextPool& pool = DeviceContextPool::Instance();
auto dev_ctx = pool.Get(cuda_place);
paddle::framework::Tensor output_tensor;
CopyFrom(output2->Get<LoDTensor>(), paddle::platform::CPUPlace(), *dev_ctx,
&output_tensor);
dev_ctx->Wait();
float* output2_ptr = output_tensor.data<float>();
for (int i = 0; i < 2 * 3; ++i) {
ASSERT_EQ(output2_ptr[i], static_cast<float>(i) * 4);
}
}
......@@ -218,7 +218,7 @@ class OpWithKernelTest : public OperatorWithKernel {
protected:
void InferShape(InferShapeContext* ctx) const override {}
framework::OpKernelType GetActualKernelType(
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(proto::DataType::FP32, ctx.device_context());
}
......@@ -282,16 +282,11 @@ class OpWithMultiKernelTest : public OperatorWithKernel {
protected:
void InferShape(InferShapeContext* ctx) const override {}
framework::OpKernelType GetActualKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(proto::DataType::FP32, ctx.device_context());
}
framework::OpKernelType GetExpectedKernelType(
const framework::OpKernelType& kernel) const override {
return framework::OpKernelType(kernel.data_type_, platform::CUDAPlace(0),
kernel.data_layout_,
framework::LibraryType::kCUDNN);
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
proto::DataType::FP32, platform::CUDAPlace(0), DataLayout::kAnyLayout,
framework::LibraryType::kCUDNN);
}
};
......@@ -371,6 +366,7 @@ TEST(OperatorRegistrar, OpWithMultiKernel) {
op_desc.set_type("op_with_multi_kernel");
auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
// TODO(qiao) add priority back
// use all available kernels
paddle::framework::UseALL();
op->Run(scope, cuda_place);
......@@ -380,16 +376,16 @@ TEST(OperatorRegistrar, OpWithMultiKernel) {
paddle::framework::UseCPU();
op->Run(scope, cpu_place);
EXPECT_EQ(op_test_value, -9);
EXPECT_EQ(op_test_value, -20);
// add cuda kernels
paddle::framework::UseCUDA();
op->Run(scope, cuda_place);
EXPECT_EQ(op_test_value, -10);
EXPECT_EQ(op_test_value, -30);
// use cudnn kernel
paddle::framework::UseCUDNN();
op->Run(scope, cuda_place);
EXPECT_EQ(op_test_value, -20);
EXPECT_EQ(op_test_value, -40);
}
......@@ -14,11 +14,10 @@ limitations under the License. */
#include <glog/logging.h>
#include <algorithm>
#include <atomic>
#include "paddle/framework/data_transform.h"
#include "paddle/framework/device_data_transform.h"
#include "paddle/framework/executor.h"
#include "paddle/framework/lod_tensor_array.h"
#include "paddle/framework/operator.h"
#include "paddle/framework/shape_inference.h"
#include "paddle/framework/var_type.h"
......@@ -243,6 +242,10 @@ void OperatorBase::GenerateTemporaryNames() {
}
}
// Returns true when `var` holds a LoDTensor or a SelectedRows.
// `var` must not be null.
static bool VarIsTensor(const Variable* var) {
  if (var->IsType<LoDTensor>()) {
    return true;
  }
  return var->IsType<SelectedRows>();
}
static const Tensor* GetTensorFromVar(const Variable* var) {
const Tensor* t = nullptr;
if (var->IsType<LoDTensor>()) {
......@@ -453,30 +456,6 @@ class RuntimeInferShapeContext : public InferShapeContext {
const Scope& scope_;
};
const platform::DeviceContext* GetDeviceContext(
framework::KernelTypePair& kernel_pair) {
auto& actual_kernel_key = kernel_pair.first;
auto& expected_kernel_key = kernel_pair.second;
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
if (platform::is_gpu_place(actual_kernel_key.place_) &&
platform::is_cpu_place(expected_kernel_key.place_)) {
return pool.Get(actual_kernel_key.place_);
} else if (platform::is_cpu_place(actual_kernel_key.place_) &&
platform::is_gpu_place(expected_kernel_key.place_)) {
return pool.Get(expected_kernel_key.place_);
} else {
PADDLE_THROW(
"Currently, model parallelism is only supported between CPU and CUDA");
}
}
const platform::DeviceContext* GetDeviceContext(
const framework::OpKernelType& kernel) {
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
return pool.Get(kernel.place_);
}
void OperatorWithKernel::Run(const Scope& scope,
const platform::Place& place) const {
RuntimeInferShapeContext infer_shape_ctx(*this, scope);
......@@ -492,94 +471,43 @@ void OperatorWithKernel::Run(const Scope& scope,
"There are no kernels which are registered in the %s operator.", type_);
}
// check if op[type] have kernel for kernel_key
OpKernelMap& kernels = kernels_iter->second;
ExecutionContext ctx(*this, scope, *dev_ctx);
auto actual_kernel_key = GetActualKernelType(ctx);
auto expected_kernel_key = GetExpectedKernelType(actual_kernel_key);
if (actual_kernel_key == expected_kernel_key) {
PADDLE_ENFORCE_EQ(actual_kernel_key.place_, expected_kernel_key.place_,
"Currently, model parallelism is only supported between "
"CPU and other devices. For example, multi-GPU model "
"parallelism will failed.");
} else {
// find the best key candidate
const DataTransformFnMap& trans_map = DataTransformFnMap::Instance();
for (auto& candidate : kKernelPriority) {
auto candidate_key =
OpKernelType(actual_kernel_key.data_type_, std::get<0>(candidate),
actual_kernel_key.data_layout_, std::get<1>(candidate));
auto candidate_pair = std::make_pair(actual_kernel_key, candidate_key);
if ((actual_kernel_key == candidate_key) ||
(kernels.count(candidate_key) &&
trans_map.GetNullable(candidate_pair))) {
expected_kernel_key = candidate_key;
break;
}
}
auto kernel_pair = std::make_pair(actual_kernel_key, expected_kernel_key);
const DataTransformFn* trans_fun = trans_map.GetNullable(kernel_pair);
if (trans_fun) {
auto input_vars = this->InputVars();
// TODO(qijun) filter the input vars that do not need to be transformed
// filter vars that has been transformed
std::vector<std::string> need_trans;
for (auto var_name : input_vars) {
auto var_name_trans =
var_name + framework::KernelTypeToString(expected_kernel_key);
if (!scope.FindVar(var_name_trans)) {
const_cast<Scope&>(scope).Var(var_name_trans);
need_trans.push_back(var_name);
}
}
if (!need_trans.empty()) {
auto trans_dev_ctx = GetDeviceContext(kernel_pair);
// Wait for transform starting
dev_ctx->Wait();
for (auto var_name : need_trans) {
(*trans_fun)(trans_dev_ctx, kernel_pair, *(scope.FindVar(var_name)),
scope.FindVar(var_name + framework::KernelTypeToString(
expected_kernel_key)));
auto expected_kernel_key = this->GetExpectedKernelType(ctx);
Scope& new_scope = scope.NewScope();
for (auto& var_name_item : this->Inputs()) {
for (auto& var_name : var_name_item.second) {
auto* var = scope.FindVar(var_name);
if (var && VarIsTensor(var)) {
auto* tensor_in = GetTensorFromVar(var);
if (tensor_in->IsInitialized()) {
auto kernel_type_for_var = this->GetKernelTypeForVar(
var_name_item.first, *tensor_in, expected_kernel_key);
if (kernel_type_for_var != expected_kernel_key) {
auto out_var_names = OutputVars(true);
if (std::find(out_var_names.begin(), out_var_names.end(),
var_name) != out_var_names.end()) {
PADDLE_THROW(
"var %s is both input and output, "
"does not support transform",
var_name);
}
VLOG(3) << "need to do transform for var " << var_name;
auto* trans_var = new_scope.Var(var_name);
auto* out = DataTransform(expected_kernel_key, kernel_type_for_var,
*tensor_in);
CopyVariableWithTensor(*var, *out, *trans_var);
}
}
// Wait for data transform finishing
trans_dev_ctx->Wait();
}
}
}
VLOG(10) << "Actual kernel: " << actual_kernel_key
<< "Expected kernel: " << expected_kernel_key;
OpKernelMap& kernels = kernels_iter->second;
auto kernel_iter = kernels.find(expected_kernel_key);
if (kernel_iter == kernels.end()) {
PADDLE_THROW("The operator %s does not support %s", type_,
expected_kernel_key);
}
auto* expected_dev_ctx = GetDeviceContext(expected_kernel_key);
ExecutionContext expected_ctx(*this, scope, *expected_dev_ctx);
kernel_iter->second->Compute(expected_ctx);
}
OpKernelType OperatorWithKernel::GetActualKernelType(
const ExecutionContext& ctx) const {
return OpKernelType(IndicateDataType(ctx), ctx.GetPlace());
}
OpKernelType OperatorWithKernel::GetExpectedKernelType(
const OpKernelType& actual_kernel_type) const {
return actual_kernel_type;
kernel_iter->second->Compute(ExecutionContext(*this, new_scope, *dev_ctx));
}
proto::DataType OperatorWithKernel::IndicateDataType(
......@@ -611,5 +539,16 @@ proto::DataType OperatorWithKernel::IndicateDataType(
return static_cast<proto::DataType>(data_type);
}
// Default expected kernel type: the data type reported by IndicateDataType
// for this execution context, on the place of the execution context.
OpKernelType OperatorWithKernel::GetExpectedKernelType(
    const ExecutionContext& ctx) const {
  const proto::DataType input_data_type = IndicateDataType(ctx);
  return OpKernelType(input_data_type, ctx.GetPlace());
}
// Default kernel type attributed to an input variable: the data type of the
// expected kernel combined with the place the tensor currently lives on.
// `var_name` is not used by this default implementation.
OpKernelType OperatorWithKernel::GetKernelTypeForVar(
const std::string& var_name, const Tensor& tensor,
const OpKernelType& expected_kernel_type) const {
return OpKernelType(expected_kernel_type.data_type_, tensor.place());
}
} // namespace framework
} // namespace paddle
......@@ -408,9 +408,10 @@ class OperatorWithKernel : public OperatorBase {
}
protected:
virtual OpKernelType GetActualKernelType(const ExecutionContext& ctx) const;
virtual OpKernelType GetExpectedKernelType(
const OpKernelType& actual_kernel_type) const;
virtual OpKernelType GetExpectedKernelType(const ExecutionContext& ctx) const;
virtual OpKernelType GetKernelTypeForVar(
const std::string& var_name, const Tensor& tensor,
const OpKernelType& expected_kernel_type) const;
private:
// Indicate the kernel DataType by the input data. By default, all input data must be
......
......@@ -114,7 +114,8 @@ class OpWithKernelTest : public OperatorWithKernel {
protected:
void InferShape(framework::InferShapeContext* ctx) const override {}
OpKernelType GetActualKernelType(const ExecutionContext& ctx) const override {
OpKernelType GetExpectedKernelType(
const ExecutionContext& ctx) const override {
return OpKernelType(proto::DataType::FP32, ctx.GetPlace());
}
};
......
......@@ -109,6 +109,7 @@ std::string Scope::Rename(const std::string& origin_name) const {
Rename(origin_name, var_name);
return var_name;
}
Variable* Scope::FindVarLocally(const std::string& name) const {
auto it = vars_.find(name);
if (it != vars_.end()) return it->second;
......
......@@ -75,9 +75,9 @@ class Scope {
// Rename variable to a new name and return the new name
std::string Rename(const std::string& origin_name) const;
private:
Variable* FindVarLocally(const std::string& name) const;
private:
// Call Scope::NewScope for a sub-scope.
explicit Scope(Scope const* parent) : parent_(parent) {}
......
......@@ -55,6 +55,8 @@ class Tensor {
template <typename T>
inline const T* data() const;
inline bool IsInitialized() const;
inline void switch_place(platform::Place new_place);
/**
......
......@@ -84,6 +84,8 @@ inline const T* Tensor::data() const {
reinterpret_cast<uintptr_t>(holder_->ptr()) + offset_);
}
// Returns true when holder_ is non-null.
// NOTE(review): presumably holder_ is set once memory has been allocated for
// or shared with this tensor — confirm against the rest of the class.
inline bool Tensor::IsInitialized() const { return holder_ != nullptr; }
template <typename T>
inline T* Tensor::data() {
check_memory_size();
......
......@@ -32,6 +32,8 @@ class Variable {
return *static_cast<const T*>(holder_->Ptr());
}
// Returns true when holder_ is non-null.
// NOTE(review): presumably this means the variable already holds a value of
// some type — confirm against the rest of the class.
bool IsInitialized() const { return holder_ != nullptr; }
template <typename T>
T* GetMutable() {
if (!IsType<T>()) {
......
......@@ -53,7 +53,7 @@ class AccuracyOp : public framework::OperatorWithKernel {
}
protected:
framework::OpKernelType GetActualKernelType(
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("Out")->type()),
......
......@@ -39,7 +39,7 @@ class AucOp : public framework::OperatorWithKernel {
}
protected:
framework::OpKernelType GetActualKernelType(
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("Out")->type()),
......
......@@ -306,7 +306,7 @@ class BatchNormGradOp : public framework::OperatorWithKernel {
}
protected:
framework::OpKernelType GetActualKernelType(
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
const auto *var = ctx.InputVar(framework::GradVarName("Y"));
if (var == nullptr) {
......
......@@ -55,10 +55,10 @@ class ChunkEvalOp : public framework::OperatorWithKernel {
}
protected:
framework::OpKernelType GetActualKernelType(
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(framework::proto::DataType::FP32,
ctx.device_context());
platform::CPUPlace());
}
};
......
......@@ -145,6 +145,7 @@ class ChunkEvalKernel : public framework::OpKernel<T> {
context.Attr<std::vector<int>>("excluded_chunk_types").end());
auto* inference = context.Input<LoDTensor>("Inference");
auto place = inference->place();
auto* label = context.Input<LoDTensor>("Label");
auto* precision = context.Output<Tensor>("Precision");
auto* recall = context.Output<Tensor>("Recall");
......@@ -155,15 +156,15 @@ class ChunkEvalKernel : public framework::OpKernel<T> {
const int64_t* inference_data = inference->data<int64_t>();
const int64_t* label_data = label->data<int64_t>();
T* precision_data = precision->mutable_data<T>(context.GetPlace());
T* racall_data = recall->mutable_data<T>(context.GetPlace());
T* f1_data = f1->mutable_data<T>(context.GetPlace());
T* precision_data = precision->mutable_data<T>(place);
T* racall_data = recall->mutable_data<T>(place);
T* f1_data = f1->mutable_data<T>(place);
int64_t* num_infer_chunks_data =
num_infer_chunks->mutable_data<int64_t>(context.GetPlace());
num_infer_chunks->mutable_data<int64_t>(place);
int64_t* num_label_chunks_data =
num_label_chunks->mutable_data<int64_t>(context.GetPlace());
num_label_chunks->mutable_data<int64_t>(place);
int64_t* num_correct_chunks_data =
num_correct_chunks->mutable_data<int64_t>(context.GetPlace());
num_correct_chunks->mutable_data<int64_t>(place);
*num_infer_chunks_data = 0;
*num_label_chunks_data = 0;
*num_correct_chunks_data = 0;
......
......@@ -66,9 +66,9 @@ class CompareOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
framework::OpKernelType GetActualKernelType(
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
framework::OpKernelType kt = OperatorWithKernel::GetActualKernelType(ctx);
framework::OpKernelType kt = OperatorWithKernel::GetExpectedKernelType(ctx);
// CompareOp kernel's device type is decided by input tensor place
kt.place_ = ctx.Input<framework::LoDTensor>("X")->place();
return kt;
......
......@@ -62,25 +62,12 @@ class ConvOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override;
framework::OpKernelType GetExpectedKernelType(
const framework::OpKernelType& kernel) const override {
return framework::OpKernelType(kernel.data_type_, platform::CUDAPlace(0),
kernel.data_layout_,
framework::LibraryType::kCUDNN);
}
};
class ConvOpGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override;
framework::OpKernelType GetExpectedKernelType(
const framework::OpKernelType& kernel) const override {
return framework::OpKernelType(kernel.data_type_, platform::CUDAPlace(0),
kernel.data_layout_,
framework::LibraryType::kCUDNN);
}
};
template <typename DeviceContext, typename T>
......
......@@ -120,17 +120,11 @@ class CRFDecodingOp : public framework::OperatorWithKernel {
}
protected:
framework::OpKernelType GetActualKernelType(
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<LoDTensor>("Emission")->type()),
ctx.device_context());
}
framework::OpKernelType GetExpectedKernelType(
const framework::OpKernelType& actual_kernel_type) const override {
return framework::OpKernelType(actual_kernel_type.data_type_,
platform::CPUPlace());
platform::CPUPlace());
}
};
} // namespace operators
......
......@@ -28,9 +28,6 @@ template <typename DeviceContext, typename T>
class CRFDecodingOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()),
"The crf_decoding operator can only run on CPU.");
auto* emission_weights = ctx.Input<LoDTensor>("Emission");
auto* transition_weights = ctx.Input<Tensor>("Transition");
auto* label = ctx.Input<LoDTensor>("Label");
......
......@@ -51,7 +51,7 @@ class CrossEntropyOp : public framework::OperatorWithKernel {
protected:
// Explicitly set that the data type of computation kernel of cross_entropy
// is determined by its input "X".
framework::OpKernelType GetActualKernelType(
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("X")->type()),
......@@ -101,7 +101,7 @@ class CrossEntropyGradientOp : public framework::OperatorWithKernel {
protected:
// Explicitly set that the data type of computation kernel of cross_entropy
// is determined by its input "X".
framework::OpKernelType GetActualKernelType(
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("X")->type()),
......
......@@ -49,7 +49,7 @@ class FillConstantBatchSizeLikeOp : public framework::OperatorWithKernel {
}
protected:
framework::OpKernelType GetActualKernelType(
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
static_cast<framework::proto::DataType>(ctx.Attr<int>("dtype")),
......
......@@ -40,7 +40,7 @@ class GatherOp : public framework::OperatorWithKernel {
}
protected:
framework::OpKernelType GetActualKernelType(
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("X")->type()),
......@@ -57,7 +57,7 @@ class GatherGradOp : public framework::OperatorWithKernel {
}
protected:
framework::OpKernelType GetActualKernelType(
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("X")->type()),
......
......@@ -60,7 +60,7 @@ class GaussianRandomOp : public framework::OperatorWithKernel {
}
protected:
framework::OpKernelType GetActualKernelType(
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
static_cast<framework::proto::DataType>(ctx.Attr<int>("dtype")),
......
......@@ -183,7 +183,7 @@ class LinearChainCRFOp : public framework::OperatorWithKernel {
protected:
// Explicitly set that the data type of computation kernel of linear_chain_crf
// is determined by its input "Emission".
framework::OpKernelType GetActualKernelType(
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<LoDTensor>("Emission")->type()),
......@@ -242,7 +242,7 @@ class LinearChainCRFGradOp : public framework::OperatorWithKernel {
protected:
// Explicitly set that the data type of output of the linear_chain_crf_grad
// operator is determined by its input: gradients of LogLikelihood.
framework::OpKernelType GetActualKernelType(
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(
......
......@@ -38,7 +38,7 @@ class LoDResetOp : public framework::OperatorWithKernel {
}
protected:
framework::OpKernelType GetActualKernelType(
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()),
......@@ -97,7 +97,7 @@ class LoDResetGradOp : public framework::OperatorWithKernel {
}
protected:
framework::OpKernelType GetActualKernelType(
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()),
......
......@@ -99,9 +99,9 @@ class LogicalOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
framework::OpKernelType GetActualKernelType(
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
framework::OpKernelType kt = OperatorWithKernel::GetActualKernelType(ctx);
framework::OpKernelType kt = OperatorWithKernel::GetExpectedKernelType(ctx);
// LogicalOp kernel's device type is decided by input tensor place
kt.place_ = ctx.Input<framework::LoDTensor>("X")->place();
return kt;
......
......@@ -41,7 +41,7 @@ class LookupTableOp : public framework::OperatorWithKernel {
}
protected:
framework::OpKernelType GetActualKernelType(
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<LoDTensor>("W")->type()),
......@@ -98,7 +98,7 @@ class LookupTableOpGrad : public framework::OperatorWithKernel {
}
protected:
framework::OpKernelType GetActualKernelType(
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<LoDTensor>("W")->type()),
......
......@@ -92,7 +92,7 @@ class LSTMOp : public framework::OperatorWithKernel {
}
protected:
framework::OpKernelType GetActualKernelType(
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::LoDTensor>("Input")->type()),
......@@ -260,7 +260,7 @@ class LSTMGradOp : public framework::OperatorWithKernel {
}
protected:
framework::OpKernelType GetActualKernelType(
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::LoDTensor>("Input")->type()),
......
......@@ -51,7 +51,7 @@ class MultiplexOp : public framework::OperatorWithKernel {
}
protected:
framework::OpKernelType GetActualKernelType(
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.MultiInput<Tensor>("X")[0]->type()),
......@@ -102,7 +102,7 @@ class MultiplexGradOp : public framework::OperatorWithKernel {
}
protected:
framework::OpKernelType GetActualKernelType(
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.MultiInput<Tensor>("X")[0]->type()),
......
......@@ -63,7 +63,7 @@ class NCEOp : public framework::OperatorWithKernel {
}
protected:
framework::OpKernelType GetActualKernelType(
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("Input")->type()),
......@@ -166,7 +166,7 @@ class NCEOpGrad : public framework::OperatorWithKernel {
}
protected:
framework::OpKernelType GetActualKernelType(
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("Input")->type()),
......
......@@ -69,7 +69,7 @@ class MaxPoolWithIndexOp : public framework::OperatorWithKernel {
}
protected:
framework::OpKernelType GetActualKernelType(
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()),
......@@ -90,7 +90,7 @@ class MaxPoolWithIndexOpGrad : public framework::OperatorWithKernel {
}
protected:
framework::OpKernelType GetActualKernelType(
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()),
......
......@@ -85,7 +85,7 @@ class PositiveNegativePairOp : public framework::OperatorWithKernel {
}
protected:
framework::OpKernelType GetActualKernelType(
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("Score")->type()),
......
......@@ -80,7 +80,7 @@ class PrecisionRecallOp : public framework::OperatorWithKernel {
}
protected:
framework::OpKernelType GetActualKernelType(
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("MaxProbs")->type()),
......
......@@ -68,7 +68,7 @@ class ROIPoolOp : public framework::OperatorWithKernel {
}
protected:
framework::OpKernelType GetActualKernelType(
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()),
......@@ -89,7 +89,7 @@ class ROIPoolGradOp : public framework::OperatorWithKernel {
}
protected:
framework::OpKernelType GetActualKernelType(
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()),
......
......@@ -49,7 +49,7 @@ class ScatterOp : public framework::OperatorWithKernel {
}
protected:
framework::OpKernelType GetActualKernelType(
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("Ref")->type()),
......@@ -68,7 +68,7 @@ class ScatterGradOp : public framework::OperatorWithKernel {
}
protected:
framework::OpKernelType GetActualKernelType(
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("Ref")->type()),
......
......@@ -107,7 +107,7 @@ class SequencePoolGradOp : public framework::OperatorWithKernel {
}
protected:
framework::OpKernelType GetActualKernelType(
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("X")->type()),
......
......@@ -48,7 +48,7 @@ class SequenceSliceOp : public framework::OperatorWithKernel {
}
protected:
framework::OpKernelType GetActualKernelType(
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()),
......@@ -69,7 +69,7 @@ class SequenceSliceGradOp : public framework::OperatorWithKernel {
}
protected:
framework::OpKernelType GetActualKernelType(
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()),
......
......@@ -118,7 +118,7 @@ class SoftmaxWithCrossEntropyOp : public framework::OperatorWithKernel {
}
protected:
framework::OpKernelType GetActualKernelType(
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("Logits")->type()),
......@@ -159,7 +159,7 @@ class SoftmaxWithCrossEntropyOpGrad : public framework::OperatorWithKernel {
}
protected:
framework::OpKernelType GetActualKernelType(
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(
......
......@@ -53,7 +53,7 @@ class SumOp : public framework::OperatorWithKernel {
}
protected:
framework::OpKernelType GetActualKernelType(
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
auto x_vars = ctx.MultiInputVar("X");
if (x_vars[0]->IsType<framework::LoDTensor>()) {
......
......@@ -63,7 +63,7 @@ class UniformRandomOp : public framework::OperatorWithKernel {
}
protected:
framework::OpKernelType GetActualKernelType(
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
static_cast<framework::proto::DataType>(ctx.Attr<int>("dtype")),
......
......@@ -71,7 +71,7 @@ int OutputSize(int input_size, int ksize, int padding, int stride) {
class UnpoolOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetActualKernelType(
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()),
......@@ -110,7 +110,7 @@ class UnpoolOp : public framework::OperatorWithKernel {
class UnpoolOpGrad : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetActualKernelType(
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()),
......
......@@ -51,6 +51,18 @@ bool places_are_same_class(const Place &p1, const Place &p2) {
return p1.which() == p2.which();
}
bool is_same_place(const Place &p1, const Place &p2) {
if (places_are_same_class(p1, p2)) {
if (is_cpu_place(p1)) {
return true;
} else {
return boost::get<CUDAPlace>(p1) == boost::get<CUDAPlace>(p2);
}
} else {
return false;
}
}
std::ostream &operator<<(std::ostream &os, const Place &p) {
detail::PlacePrinter printer(os);
boost::apply_visitor(printer, p);
......
......@@ -61,6 +61,7 @@ const CPUPlace default_cpu();
bool is_gpu_place(const Place &);
bool is_cpu_place(const Place &);
bool places_are_same_class(const Place &, const Place &);
bool is_same_place(const Place &, const Place &);
std::ostream &operator<<(std::ostream &, const Place &);
......
......@@ -4,6 +4,7 @@ import numpy as np
import paddle.v2 as paddle
import paddle.v2.dataset.conll05 as conll05
import paddle.v2.fluid as fluid
import time
word_dict, verb_dict, label_dict = conll05.get_dict()
word_dict_len = len(word_dict)
......@@ -160,7 +161,8 @@ def main():
paddle.reader.shuffle(
paddle.dataset.conll05.test(), buf_size=8192),
batch_size=BATCH_SIZE)
place = fluid.CPUPlace()
#place = fluid.CPUPlace()
place = fluid.CUDAPlace(0)
feeder = fluid.DataFeeder(
feed_list=[
word, ctx_n2, ctx_n1, ctx_0, ctx_p1, ctx_p2, predicate, mark, target
......@@ -174,6 +176,7 @@ def main():
embedding_param.set(
load_parameter(conll05.get_embedding(), word_dict_len, word_dim), place)
start_time = time.time()
batch_id = 0
for pass_id in xrange(PASS_NUM):
chunk_evaluator.reset(exe)
......@@ -191,6 +194,9 @@ def main():
f1_score) + " pass_precision:" + str(
pass_precision) + " pass_recall:" + str(pass_recall)
+ " pass_f1_score:" + str(pass_f1_score))
if batch_id != 0:
print("second per batch: " + str((time.time() - start_time)
/ batch_id))
# exit early for CI
exit(0)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册