未验证 提交 0f353ab4 编写于 作者: Q Qiao Longfei 提交者: GitHub

cpu gpu transform function (#7191)

* add rename guard

* add device_data_transform

* add device_data_transform_test

* modify GetExpectedKernelType

* update operator.run

* support test test_label_semantic_roles

* optimize code

* optimize code

* rename GetActualKernelType to GetExpectedKernelType

* fix chunk_eval_op and device_data_transform_test

* add is_same_place to place

* optimize code, refine rename_guard

* refine rename guard, add GetKernelTypeForVar

* optimize code

* add some log

* rename guard

* use sub scope to create var

* fix compile

* add IsInitialized for Tensor

* add VarIsTensor

* fix op_registry_test

* test

* tmp disable priority

* restore switch_kernel.md

* code clean
上级 8814bec0
...@@ -32,7 +32,9 @@ cc_test(threadpool_test SRCS threadpool_test.cc DEPS threadpool) ...@@ -32,7 +32,9 @@ cc_test(threadpool_test SRCS threadpool_test.cc DEPS threadpool)
cc_library(scope SRCS scope.cc DEPS glog threadpool) cc_library(scope SRCS scope.cc DEPS glog threadpool)
cc_test(scope_test SRCS scope_test.cc DEPS scope) cc_test(scope_test SRCS scope_test.cc DEPS scope)
cc_library(data_transform SRCS data_transform.cc DEPS math_function tensor framework_proto) cc_library(device_data_transform SRCS device_data_transform.cc DEPS tensor)
cc_library(data_transform SRCS data_transform.cc DEPS math_function tensor framework_proto selected_rows device_data_transform)
cc_test(data_transform_test SRCS data_transform_test.cc DEPS data_transform device_context) cc_test(data_transform_test SRCS data_transform_test.cc DEPS data_transform device_context)
cc_library(attribute SRCS attribute.cc DEPS framework_proto) cc_library(attribute SRCS attribute.cc DEPS framework_proto)
...@@ -77,3 +79,6 @@ cc_library(init SRCS init.cc DEPS gflags device_context place stringpiece operat ...@@ -77,3 +79,6 @@ cc_library(init SRCS init.cc DEPS gflags device_context place stringpiece operat
cc_test(init_test SRCS init_test.cc DEPS init) cc_test(init_test SRCS init_test.cc DEPS init)
cc_test(op_kernel_type_test SRCS op_kernel_type_test.cc DEPS place device_context framework_proto) cc_test(op_kernel_type_test SRCS op_kernel_type_test.cc DEPS place device_context framework_proto)
nv_test(device_data_transform_test SRCS device_data_transform_test.cu
DEPS operator op_registry init math_function)
...@@ -14,7 +14,9 @@ limitations under the License. */ ...@@ -14,7 +14,9 @@ limitations under the License. */
#include <functional> #include <functional>
#include "paddle/framework/data_transform.h" #include "paddle/framework/data_transform.h"
#include "paddle/framework/device_data_transform.h"
#include "paddle/framework/lod_tensor.h" #include "paddle/framework/lod_tensor.h"
#include "paddle/framework/selected_rows.h"
#include "paddle/platform/device_context.h" #include "paddle/platform/device_context.h"
namespace paddle { namespace paddle {
...@@ -25,6 +27,37 @@ DataTransformFnMap& DataTransformFnMap::Instance() { ...@@ -25,6 +27,37 @@ DataTransformFnMap& DataTransformFnMap::Instance() {
return data_transform_map; return data_transform_map;
} }
Tensor* DataTransform(const OpKernelType& expected_kernel_type,
const OpKernelType& kernel_type_for_var,
const Tensor& input_tensor) {
Tensor* out = nullptr;
if (!platform::is_same_place(kernel_type_for_var.place_,
expected_kernel_type.place_)) {
out = DeviceTransform(input_tensor, expected_kernel_type.place_);
}
PADDLE_ENFORCE_NOT_NULL(out, "out should not be null");
return out;
}
void CopyVariableWithTensor(const Variable& in_var, const Tensor& tensor,
Variable& out_var) {
if (in_var.IsType<LoDTensor>()) {
auto& in_lod_tensor = in_var.Get<LoDTensor>();
auto* tran_lod_tensor = out_var.GetMutable<LoDTensor>();
tran_lod_tensor->set_lod(in_lod_tensor.lod());
tran_lod_tensor->set_layout(in_lod_tensor.layout());
tran_lod_tensor->ShareDataWith(tensor);
} else if (in_var.IsType<SelectedRows>()) {
auto& in_selected_rows = in_var.Get<SelectedRows>();
auto* trans_selected_rows = out_var.GetMutable<SelectedRows>();
trans_selected_rows->set_height(in_selected_rows.height());
trans_selected_rows->set_rows(in_selected_rows.rows());
trans_selected_rows->mutable_value()->ShareDataWith(tensor);
} else {
PADDLE_THROW("unknown var type");
}
}
auto KernelFP32 = OpKernelType(proto::DataType::FP32, platform::CPUPlace(), auto KernelFP32 = OpKernelType(proto::DataType::FP32, platform::CPUPlace(),
DataLayout::kNHWC, LibraryType::kPlain); DataLayout::kNHWC, LibraryType::kPlain);
......
...@@ -19,6 +19,7 @@ limitations under the License. */ ...@@ -19,6 +19,7 @@ limitations under the License. */
#include <vector> #include <vector>
#include "paddle/framework/op_kernel_type.h" #include "paddle/framework/op_kernel_type.h"
#include "paddle/framework/selected_rows.h"
#include "paddle/framework/tensor.h" #include "paddle/framework/tensor.h"
#include "paddle/framework/variable.h" #include "paddle/framework/variable.h"
#include "paddle/operators/math/math_function.h" #include "paddle/operators/math/math_function.h"
...@@ -49,6 +50,13 @@ struct KernelTypePairHash { ...@@ -49,6 +50,13 @@ struct KernelTypePairHash {
} }
}; };
Tensor* DataTransform(const OpKernelType& expected_kernel_type,
const OpKernelType& kernel_type_for_var,
const Tensor& input_tensor);
void CopyVariableWithTensor(const Variable& in_var, const Tensor& tensor,
Variable& out_var);
template <typename InType, typename OutType> template <typename InType, typename OutType>
struct CastDataTypeFunctor { struct CastDataTypeFunctor {
HOSTDEVICE inline OutType operator()(InType in) const { HOSTDEVICE inline OutType operator()(InType in) const {
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/framework/device_data_transform.h"
namespace paddle {
namespace framework {
static const platform::DeviceContext* GetDeviceContext(
const platform::Place& src_place, const platform::Place& dst_place) {
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
if (platform::is_gpu_place(src_place) && platform::is_cpu_place(dst_place)) {
return pool.Get(src_place);
} else if (platform::is_cpu_place(src_place) &&
platform::is_gpu_place(dst_place)) {
return pool.Get(dst_place);
} else {
PADDLE_THROW(
"Currently, model parallelism is only supported between CPU and CUDA");
}
}
Tensor* DeviceTransform(const Tensor& in, const platform::Place& dst_place) {
VLOG(3) << "DeviceTransform in, src_place " << in.place()
<< " dst_place: " << dst_place;
Tensor* out = new Tensor();
auto* dev_ctx = GetDeviceContext(in.place(), dst_place);
dev_ctx->Wait();
CopyFrom(in, dst_place, *dev_ctx, out);
dev_ctx->Wait();
return out;
}
} // namespace framework
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/lod_tensor.h"
#include "paddle/framework/tensor.h"
#include "paddle/framework/tensor_util.h"
#include "paddle/platform/device_context.h"
namespace paddle {
namespace framework {
Tensor* DeviceTransform(const Tensor& in, const platform::Place& dst_place);
} // namespace framework
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "gtest/gtest.h"
#include "paddle/framework/init.h"
#include "paddle/framework/lod_tensor.h"
#include "paddle/framework/op_info.h"
#include "paddle/framework/op_registry.h"
#include "paddle/operators/elementwise_op_function.h"
#include "paddle/operators/math/math_function.h"
#include "paddle/platform/device_context.h"
namespace paddle {
namespace framework {
template <typename T>
struct AddFunctor {
inline HOSTDEVICE T operator()(T a, T b) const { return a + b; }
};
class OpKernelTestProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
public:
OpKernelTestProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("input", "input1 of test op");
AddOutput("output", "output of test op");
AddAttr<bool>("use_gpu", "force to use gpu kernel").SetDefault(false);
AddComment("This is test op");
}
};
class TestOpWithKernel : public OperatorWithKernel {
public:
using OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext* ctx) const override {}
OpKernelType GetExpectedKernelType(
const ExecutionContext& ctx) const override {
if (Attr<bool>("use_gpu")) {
VLOG(3) << "force use gpu kernel";
return OpKernelType(proto::DataType::FP32, platform::CUDAPlace(0));
} else {
VLOG(3) << "use default kernel";
return OpKernelType(proto::DataType::FP32,
ctx.Input<Tensor>("input")->place());
}
}
};
template <typename DeviceContext, typename T>
class TestKernel : public OpKernel<float> {
public:
void Compute(const ExecutionContext& ctx) const {
std::cout << ctx.op().DebugString() << std::endl;
const Tensor* input = ctx.Input<Tensor>("input");
std::cout << "input place:" << input->place() << std::endl;
auto* output = ctx.Output<framework::LoDTensor>("output");
output->Resize(input->dims());
output->mutable_data<T>(ctx.GetPlace());
operators::TransformFunctor<AddFunctor<T>, T, DeviceContext> functor(
input, input, output, ctx.template device_context<DeviceContext>(),
AddFunctor<T>());
functor.Run();
}
};
} // namespace framework
} // namespace paddle
REGISTER_OP_WITHOUT_GRADIENT(
test_op, paddle::framework::TestOpWithKernel,
paddle::framework::OpKernelTestProtoAndCheckerMaker);
REGISTER_OP_CPU_KERNEL(
test_op,
paddle::framework::TestKernel<paddle::platform::CPUDeviceContext, float>);
REGISTER_OP_CUDA_KERNEL(
test_op,
paddle::framework::TestKernel<paddle::platform::CUDADeviceContext, float>);
static void BuildVar(const std::string& param_name,
std::initializer_list<const char*> arguments,
paddle::framework::proto::OpDesc::Var* var) {
var->set_parameter(param_name);
for (auto& arg_name : arguments) {
*var->mutable_arguments()->Add() = arg_name;
}
}
TEST(Operator, CPUtoGPU) {
using namespace paddle::framework;
using namespace paddle::platform;
ASSERT_EQ(InitDevices({"CPU", "GPU:0"}), true);
paddle::framework::Scope scope;
paddle::platform::CPUPlace cpu_place;
// create an op to run on CPU
paddle::framework::proto::OpDesc cpu_op_desc;
cpu_op_desc.set_type("test_op");
BuildVar("input", {"IN1"}, cpu_op_desc.add_inputs());
BuildVar("output", {"OUT1"}, cpu_op_desc.add_outputs());
auto cpu_op = paddle::framework::OpRegistry::CreateOp(cpu_op_desc);
// prepare input
auto* in_t = scope.Var("IN1")->GetMutable<LoDTensor>();
auto* src_ptr = in_t->mutable_data<float>({2, 3}, CPUPlace());
for (int i = 0; i < 2 * 3; ++i) {
src_ptr[i] = static_cast<float>(i);
}
// get output
auto* output = scope.Var("OUT1");
cpu_op->Run(scope, cpu_place);
auto* output_ptr = output->Get<LoDTensor>().data<float>();
for (int i = 0; i < 2 * 3; ++i) {
ASSERT_EQ(output_ptr[i], static_cast<float>(i) * 2);
}
// create an op to run on GPU
paddle::framework::proto::OpDesc gpu_op_desc;
gpu_op_desc.set_type("test_op");
BuildVar("input", {"OUT1"}, gpu_op_desc.add_inputs());
BuildVar("output", {"OUT2"}, gpu_op_desc.add_outputs());
auto attr = gpu_op_desc.mutable_attrs()->Add();
attr->set_name("use_gpu");
attr->set_type(paddle::framework::proto::AttrType::BOOLEAN);
attr->set_b(true);
auto gpu_op = paddle::framework::OpRegistry::CreateOp(gpu_op_desc);
paddle::platform::CUDAPlace cuda_place(0);
// get output
auto* output2 = scope.Var("OUT2");
gpu_op->Run(scope, cuda_place);
// auto* output2_ptr = output2->Get<LoDTensor>().data<float>();
DeviceContextPool& pool = DeviceContextPool::Instance();
auto dev_ctx = pool.Get(cuda_place);
paddle::framework::Tensor output_tensor;
CopyFrom(output2->Get<LoDTensor>(), paddle::platform::CPUPlace(), *dev_ctx,
&output_tensor);
dev_ctx->Wait();
float* output2_ptr = output_tensor.data<float>();
for (int i = 0; i < 2 * 3; ++i) {
ASSERT_EQ(output2_ptr[i], static_cast<float>(i) * 4);
}
}
...@@ -218,7 +218,7 @@ class OpWithKernelTest : public OperatorWithKernel { ...@@ -218,7 +218,7 @@ class OpWithKernelTest : public OperatorWithKernel {
protected: protected:
void InferShape(InferShapeContext* ctx) const override {} void InferShape(InferShapeContext* ctx) const override {}
framework::OpKernelType GetActualKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(proto::DataType::FP32, ctx.device_context()); return framework::OpKernelType(proto::DataType::FP32, ctx.device_context());
} }
...@@ -282,16 +282,11 @@ class OpWithMultiKernelTest : public OperatorWithKernel { ...@@ -282,16 +282,11 @@ class OpWithMultiKernelTest : public OperatorWithKernel {
protected: protected:
void InferShape(InferShapeContext* ctx) const override {} void InferShape(InferShapeContext* ctx) const override {}
framework::OpKernelType GetActualKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(proto::DataType::FP32, ctx.device_context());
}
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::OpKernelType& kernel) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(kernel.data_type_, platform::CUDAPlace(0), return framework::OpKernelType(
kernel.data_layout_, proto::DataType::FP32, platform::CUDAPlace(0), DataLayout::kAnyLayout,
framework::LibraryType::kCUDNN); framework::LibraryType::kCUDNN);
} }
}; };
...@@ -371,6 +366,7 @@ TEST(OperatorRegistrar, OpWithMultiKernel) { ...@@ -371,6 +366,7 @@ TEST(OperatorRegistrar, OpWithMultiKernel) {
op_desc.set_type("op_with_multi_kernel"); op_desc.set_type("op_with_multi_kernel");
auto op = paddle::framework::OpRegistry::CreateOp(op_desc); auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
// TODO(qiao) add priority back
// use all available kernels // use all available kernels
paddle::framework::UseALL(); paddle::framework::UseALL();
op->Run(scope, cuda_place); op->Run(scope, cuda_place);
...@@ -380,16 +376,16 @@ TEST(OperatorRegistrar, OpWithMultiKernel) { ...@@ -380,16 +376,16 @@ TEST(OperatorRegistrar, OpWithMultiKernel) {
paddle::framework::UseCPU(); paddle::framework::UseCPU();
op->Run(scope, cpu_place); op->Run(scope, cpu_place);
EXPECT_EQ(op_test_value, -9); EXPECT_EQ(op_test_value, -20);
// add cuda kernels // add cuda kernels
paddle::framework::UseCUDA(); paddle::framework::UseCUDA();
op->Run(scope, cuda_place); op->Run(scope, cuda_place);
EXPECT_EQ(op_test_value, -10); EXPECT_EQ(op_test_value, -30);
// use cudnn kernel // use cudnn kernel
paddle::framework::UseCUDNN(); paddle::framework::UseCUDNN();
op->Run(scope, cuda_place); op->Run(scope, cuda_place);
EXPECT_EQ(op_test_value, -20); EXPECT_EQ(op_test_value, -40);
} }
...@@ -14,11 +14,10 @@ limitations under the License. */ ...@@ -14,11 +14,10 @@ limitations under the License. */
#include <glog/logging.h> #include <glog/logging.h>
#include <algorithm> #include <algorithm>
#include <atomic>
#include "paddle/framework/data_transform.h" #include "paddle/framework/data_transform.h"
#include "paddle/framework/device_data_transform.h"
#include "paddle/framework/executor.h" #include "paddle/framework/executor.h"
#include "paddle/framework/lod_tensor_array.h"
#include "paddle/framework/operator.h" #include "paddle/framework/operator.h"
#include "paddle/framework/shape_inference.h" #include "paddle/framework/shape_inference.h"
#include "paddle/framework/var_type.h" #include "paddle/framework/var_type.h"
...@@ -243,6 +242,10 @@ void OperatorBase::GenerateTemporaryNames() { ...@@ -243,6 +242,10 @@ void OperatorBase::GenerateTemporaryNames() {
} }
} }
static bool VarIsTensor(const Variable* var) {
return var->IsType<LoDTensor>() || var->IsType<SelectedRows>();
}
static const Tensor* GetTensorFromVar(const Variable* var) { static const Tensor* GetTensorFromVar(const Variable* var) {
const Tensor* t = nullptr; const Tensor* t = nullptr;
if (var->IsType<LoDTensor>()) { if (var->IsType<LoDTensor>()) {
...@@ -453,30 +456,6 @@ class RuntimeInferShapeContext : public InferShapeContext { ...@@ -453,30 +456,6 @@ class RuntimeInferShapeContext : public InferShapeContext {
const Scope& scope_; const Scope& scope_;
}; };
const platform::DeviceContext* GetDeviceContext(
framework::KernelTypePair& kernel_pair) {
auto& actual_kernel_key = kernel_pair.first;
auto& expected_kernel_key = kernel_pair.second;
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
if (platform::is_gpu_place(actual_kernel_key.place_) &&
platform::is_cpu_place(expected_kernel_key.place_)) {
return pool.Get(actual_kernel_key.place_);
} else if (platform::is_cpu_place(actual_kernel_key.place_) &&
platform::is_gpu_place(expected_kernel_key.place_)) {
return pool.Get(expected_kernel_key.place_);
} else {
PADDLE_THROW(
"Currently, model parallelism is only supported between CPU and CUDA");
}
}
const platform::DeviceContext* GetDeviceContext(
const framework::OpKernelType& kernel) {
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
return pool.Get(kernel.place_);
}
void OperatorWithKernel::Run(const Scope& scope, void OperatorWithKernel::Run(const Scope& scope,
const platform::Place& place) const { const platform::Place& place) const {
RuntimeInferShapeContext infer_shape_ctx(*this, scope); RuntimeInferShapeContext infer_shape_ctx(*this, scope);
...@@ -492,94 +471,43 @@ void OperatorWithKernel::Run(const Scope& scope, ...@@ -492,94 +471,43 @@ void OperatorWithKernel::Run(const Scope& scope,
"There are no kernels which are registered in the %s operator.", type_); "There are no kernels which are registered in the %s operator.", type_);
} }
// check if op[type] have kernel for kernel_key
OpKernelMap& kernels = kernels_iter->second;
ExecutionContext ctx(*this, scope, *dev_ctx); ExecutionContext ctx(*this, scope, *dev_ctx);
auto actual_kernel_key = GetActualKernelType(ctx); auto expected_kernel_key = this->GetExpectedKernelType(ctx);
auto expected_kernel_key = GetExpectedKernelType(actual_kernel_key); Scope& new_scope = scope.NewScope();
if (actual_kernel_key == expected_kernel_key) { for (auto& var_name_item : this->Inputs()) {
PADDLE_ENFORCE_EQ(actual_kernel_key.place_, expected_kernel_key.place_, for (auto& var_name : var_name_item.second) {
"Currently, model parallelism is only supported between " auto* var = scope.FindVar(var_name);
"CPU and other devices. For example, multi-GPU model " if (var && VarIsTensor(var)) {
"parallelism will failed."); auto* tensor_in = GetTensorFromVar(var);
} else { if (tensor_in->IsInitialized()) {
// find the best key candidate auto kernel_type_for_var = this->GetKernelTypeForVar(
const DataTransformFnMap& trans_map = DataTransformFnMap::Instance(); var_name_item.first, *tensor_in, expected_kernel_key);
for (auto& candidate : kKernelPriority) { if (kernel_type_for_var != expected_kernel_key) {
auto candidate_key = auto out_var_names = OutputVars(true);
OpKernelType(actual_kernel_key.data_type_, std::get<0>(candidate), if (std::find(out_var_names.begin(), out_var_names.end(),
actual_kernel_key.data_layout_, std::get<1>(candidate)); var_name) != out_var_names.end()) {
PADDLE_THROW(
auto candidate_pair = std::make_pair(actual_kernel_key, candidate_key); "var %s is both input and output, "
if ((actual_kernel_key == candidate_key) || "does not support transform",
(kernels.count(candidate_key) && var_name);
trans_map.GetNullable(candidate_pair))) { }
expected_kernel_key = candidate_key; VLOG(3) << "need to do transform for var " << var_name;
break; auto* trans_var = new_scope.Var(var_name);
} auto* out = DataTransform(expected_kernel_key, kernel_type_for_var,
} *tensor_in);
CopyVariableWithTensor(*var, *out, *trans_var);
auto kernel_pair = std::make_pair(actual_kernel_key, expected_kernel_key); }
const DataTransformFn* trans_fun = trans_map.GetNullable(kernel_pair);
if (trans_fun) {
auto input_vars = this->InputVars();
// TODO(qijun) filter the input vars that do not need to be transformed
// filter vars that has been transformed
std::vector<std::string> need_trans;
for (auto var_name : input_vars) {
auto var_name_trans =
var_name + framework::KernelTypeToString(expected_kernel_key);
if (!scope.FindVar(var_name_trans)) {
const_cast<Scope&>(scope).Var(var_name_trans);
need_trans.push_back(var_name);
}
}
if (!need_trans.empty()) {
auto trans_dev_ctx = GetDeviceContext(kernel_pair);
// Wait for transform starting
dev_ctx->Wait();
for (auto var_name : need_trans) {
(*trans_fun)(trans_dev_ctx, kernel_pair, *(scope.FindVar(var_name)),
scope.FindVar(var_name + framework::KernelTypeToString(
expected_kernel_key)));
} }
// Wait for data transform finishing
trans_dev_ctx->Wait();
} }
} }
} }
VLOG(10) << "Actual kernel: " << actual_kernel_key OpKernelMap& kernels = kernels_iter->second;
<< "Expected kernel: " << expected_kernel_key;
auto kernel_iter = kernels.find(expected_kernel_key); auto kernel_iter = kernels.find(expected_kernel_key);
if (kernel_iter == kernels.end()) { kernel_iter->second->Compute(ExecutionContext(*this, new_scope, *dev_ctx));
PADDLE_THROW("The operator %s does not support %s", type_,
expected_kernel_key);
}
auto* expected_dev_ctx = GetDeviceContext(expected_kernel_key);
ExecutionContext expected_ctx(*this, scope, *expected_dev_ctx);
kernel_iter->second->Compute(expected_ctx);
}
OpKernelType OperatorWithKernel::GetActualKernelType(
const ExecutionContext& ctx) const {
return OpKernelType(IndicateDataType(ctx), ctx.GetPlace());
}
OpKernelType OperatorWithKernel::GetExpectedKernelType(
const OpKernelType& actual_kernel_type) const {
return actual_kernel_type;
} }
proto::DataType OperatorWithKernel::IndicateDataType( proto::DataType OperatorWithKernel::IndicateDataType(
...@@ -611,5 +539,16 @@ proto::DataType OperatorWithKernel::IndicateDataType( ...@@ -611,5 +539,16 @@ proto::DataType OperatorWithKernel::IndicateDataType(
return static_cast<proto::DataType>(data_type); return static_cast<proto::DataType>(data_type);
} }
OpKernelType OperatorWithKernel::GetExpectedKernelType(
const ExecutionContext& ctx) const {
return OpKernelType(IndicateDataType(ctx), ctx.GetPlace());
}
OpKernelType OperatorWithKernel::GetKernelTypeForVar(
const std::string& var_name, const Tensor& tensor,
const OpKernelType& expected_kernel_type) const {
return OpKernelType(expected_kernel_type.data_type_, tensor.place());
}
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -408,9 +408,10 @@ class OperatorWithKernel : public OperatorBase { ...@@ -408,9 +408,10 @@ class OperatorWithKernel : public OperatorBase {
} }
protected: protected:
virtual OpKernelType GetActualKernelType(const ExecutionContext& ctx) const; virtual OpKernelType GetExpectedKernelType(const ExecutionContext& ctx) const;
virtual OpKernelType GetExpectedKernelType( virtual OpKernelType GetKernelTypeForVar(
const OpKernelType& actual_kernel_type) const; const std::string& var_name, const Tensor& tensor,
const OpKernelType& expected_kernel_type) const;
private: private:
// indicate kernel DataType by input data. Defaultly all input data must be // indicate kernel DataType by input data. Defaultly all input data must be
......
...@@ -114,7 +114,8 @@ class OpWithKernelTest : public OperatorWithKernel { ...@@ -114,7 +114,8 @@ class OpWithKernelTest : public OperatorWithKernel {
protected: protected:
void InferShape(framework::InferShapeContext* ctx) const override {} void InferShape(framework::InferShapeContext* ctx) const override {}
OpKernelType GetActualKernelType(const ExecutionContext& ctx) const override { OpKernelType GetExpectedKernelType(
const ExecutionContext& ctx) const override {
return OpKernelType(proto::DataType::FP32, ctx.GetPlace()); return OpKernelType(proto::DataType::FP32, ctx.GetPlace());
} }
}; };
......
...@@ -109,6 +109,7 @@ std::string Scope::Rename(const std::string& origin_name) const { ...@@ -109,6 +109,7 @@ std::string Scope::Rename(const std::string& origin_name) const {
Rename(origin_name, var_name); Rename(origin_name, var_name);
return var_name; return var_name;
} }
Variable* Scope::FindVarLocally(const std::string& name) const { Variable* Scope::FindVarLocally(const std::string& name) const {
auto it = vars_.find(name); auto it = vars_.find(name);
if (it != vars_.end()) return it->second; if (it != vars_.end()) return it->second;
......
...@@ -75,9 +75,9 @@ class Scope { ...@@ -75,9 +75,9 @@ class Scope {
// Rename variable to a new name and return the new name // Rename variable to a new name and return the new name
std::string Rename(const std::string& origin_name) const; std::string Rename(const std::string& origin_name) const;
private:
Variable* FindVarLocally(const std::string& name) const; Variable* FindVarLocally(const std::string& name) const;
private:
// Call Scope::NewScope for a sub-scope. // Call Scope::NewScope for a sub-scope.
explicit Scope(Scope const* parent) : parent_(parent) {} explicit Scope(Scope const* parent) : parent_(parent) {}
......
...@@ -55,6 +55,8 @@ class Tensor { ...@@ -55,6 +55,8 @@ class Tensor {
template <typename T> template <typename T>
inline const T* data() const; inline const T* data() const;
inline bool IsInitialized() const;
inline void switch_place(platform::Place new_place); inline void switch_place(platform::Place new_place);
/** /**
......
...@@ -84,6 +84,8 @@ inline const T* Tensor::data() const { ...@@ -84,6 +84,8 @@ inline const T* Tensor::data() const {
reinterpret_cast<uintptr_t>(holder_->ptr()) + offset_); reinterpret_cast<uintptr_t>(holder_->ptr()) + offset_);
} }
inline bool Tensor::IsInitialized() const { return holder_ != nullptr; }
template <typename T> template <typename T>
inline T* Tensor::data() { inline T* Tensor::data() {
check_memory_size(); check_memory_size();
......
...@@ -32,6 +32,8 @@ class Variable { ...@@ -32,6 +32,8 @@ class Variable {
return *static_cast<const T*>(holder_->Ptr()); return *static_cast<const T*>(holder_->Ptr());
} }
bool IsInitialized() const { return holder_ != nullptr; }
template <typename T> template <typename T>
T* GetMutable() { T* GetMutable() {
if (!IsType<T>()) { if (!IsType<T>()) {
......
...@@ -53,7 +53,7 @@ class AccuracyOp : public framework::OperatorWithKernel { ...@@ -53,7 +53,7 @@ class AccuracyOp : public framework::OperatorWithKernel {
} }
protected: protected:
framework::OpKernelType GetActualKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("Out")->type()), framework::ToDataType(ctx.Input<Tensor>("Out")->type()),
......
...@@ -39,7 +39,7 @@ class AucOp : public framework::OperatorWithKernel { ...@@ -39,7 +39,7 @@ class AucOp : public framework::OperatorWithKernel {
} }
protected: protected:
framework::OpKernelType GetActualKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("Out")->type()), framework::ToDataType(ctx.Input<Tensor>("Out")->type()),
......
...@@ -306,7 +306,7 @@ class BatchNormGradOp : public framework::OperatorWithKernel { ...@@ -306,7 +306,7 @@ class BatchNormGradOp : public framework::OperatorWithKernel {
} }
protected: protected:
framework::OpKernelType GetActualKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
const auto *var = ctx.InputVar(framework::GradVarName("Y")); const auto *var = ctx.InputVar(framework::GradVarName("Y"));
if (var == nullptr) { if (var == nullptr) {
......
...@@ -55,10 +55,10 @@ class ChunkEvalOp : public framework::OperatorWithKernel { ...@@ -55,10 +55,10 @@ class ChunkEvalOp : public framework::OperatorWithKernel {
} }
protected: protected:
framework::OpKernelType GetActualKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(framework::proto::DataType::FP32, return framework::OpKernelType(framework::proto::DataType::FP32,
ctx.device_context()); platform::CPUPlace());
} }
}; };
......
...@@ -145,6 +145,7 @@ class ChunkEvalKernel : public framework::OpKernel<T> { ...@@ -145,6 +145,7 @@ class ChunkEvalKernel : public framework::OpKernel<T> {
context.Attr<std::vector<int>>("excluded_chunk_types").end()); context.Attr<std::vector<int>>("excluded_chunk_types").end());
auto* inference = context.Input<LoDTensor>("Inference"); auto* inference = context.Input<LoDTensor>("Inference");
auto place = inference->place();
auto* label = context.Input<LoDTensor>("Label"); auto* label = context.Input<LoDTensor>("Label");
auto* precision = context.Output<Tensor>("Precision"); auto* precision = context.Output<Tensor>("Precision");
auto* recall = context.Output<Tensor>("Recall"); auto* recall = context.Output<Tensor>("Recall");
...@@ -155,15 +156,15 @@ class ChunkEvalKernel : public framework::OpKernel<T> { ...@@ -155,15 +156,15 @@ class ChunkEvalKernel : public framework::OpKernel<T> {
const int64_t* inference_data = inference->data<int64_t>(); const int64_t* inference_data = inference->data<int64_t>();
const int64_t* label_data = label->data<int64_t>(); const int64_t* label_data = label->data<int64_t>();
T* precision_data = precision->mutable_data<T>(context.GetPlace()); T* precision_data = precision->mutable_data<T>(place);
T* racall_data = recall->mutable_data<T>(context.GetPlace()); T* racall_data = recall->mutable_data<T>(place);
T* f1_data = f1->mutable_data<T>(context.GetPlace()); T* f1_data = f1->mutable_data<T>(place);
int64_t* num_infer_chunks_data = int64_t* num_infer_chunks_data =
num_infer_chunks->mutable_data<int64_t>(context.GetPlace()); num_infer_chunks->mutable_data<int64_t>(place);
int64_t* num_label_chunks_data = int64_t* num_label_chunks_data =
num_label_chunks->mutable_data<int64_t>(context.GetPlace()); num_label_chunks->mutable_data<int64_t>(place);
int64_t* num_correct_chunks_data = int64_t* num_correct_chunks_data =
num_correct_chunks->mutable_data<int64_t>(context.GetPlace()); num_correct_chunks->mutable_data<int64_t>(place);
*num_infer_chunks_data = 0; *num_infer_chunks_data = 0;
*num_label_chunks_data = 0; *num_label_chunks_data = 0;
*num_correct_chunks_data = 0; *num_correct_chunks_data = 0;
......
...@@ -66,9 +66,9 @@ class CompareOp : public framework::OperatorWithKernel { ...@@ -66,9 +66,9 @@ class CompareOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
protected: protected:
framework::OpKernelType GetActualKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
framework::OpKernelType kt = OperatorWithKernel::GetActualKernelType(ctx); framework::OpKernelType kt = OperatorWithKernel::GetExpectedKernelType(ctx);
// CompareOp kernel's device type is decided by input tensor place // CompareOp kernel's device type is decided by input tensor place
kt.place_ = ctx.Input<framework::LoDTensor>("X")->place(); kt.place_ = ctx.Input<framework::LoDTensor>("X")->place();
return kt; return kt;
......
...@@ -62,25 +62,12 @@ class ConvOp : public framework::OperatorWithKernel { ...@@ -62,25 +62,12 @@ class ConvOp : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override; void InferShape(framework::InferShapeContext* ctx) const override;
framework::OpKernelType GetExpectedKernelType(
const framework::OpKernelType& kernel) const override {
return framework::OpKernelType(kernel.data_type_, platform::CUDAPlace(0),
kernel.data_layout_,
framework::LibraryType::kCUDNN);
}
}; };
class ConvOpGrad : public framework::OperatorWithKernel { class ConvOpGrad : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override; void InferShape(framework::InferShapeContext* ctx) const override;
framework::OpKernelType GetExpectedKernelType(
const framework::OpKernelType& kernel) const override {
return framework::OpKernelType(kernel.data_type_, platform::CUDAPlace(0),
kernel.data_layout_,
framework::LibraryType::kCUDNN);
}
}; };
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
......
...@@ -120,17 +120,11 @@ class CRFDecodingOp : public framework::OperatorWithKernel { ...@@ -120,17 +120,11 @@ class CRFDecodingOp : public framework::OperatorWithKernel {
} }
protected: protected:
framework::OpKernelType GetActualKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(
framework::ToDataType(ctx.Input<LoDTensor>("Emission")->type()), framework::ToDataType(ctx.Input<LoDTensor>("Emission")->type()),
ctx.device_context()); platform::CPUPlace());
}
framework::OpKernelType GetExpectedKernelType(
const framework::OpKernelType& actual_kernel_type) const override {
return framework::OpKernelType(actual_kernel_type.data_type_,
platform::CPUPlace());
} }
}; };
} // namespace operators } // namespace operators
......
...@@ -28,9 +28,6 @@ template <typename DeviceContext, typename T> ...@@ -28,9 +28,6 @@ template <typename DeviceContext, typename T>
class CRFDecodingOpKernel : public framework::OpKernel<T> { class CRFDecodingOpKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()),
"The crf_decoding operator can only run on CPU.");
auto* emission_weights = ctx.Input<LoDTensor>("Emission"); auto* emission_weights = ctx.Input<LoDTensor>("Emission");
auto* transition_weights = ctx.Input<Tensor>("Transition"); auto* transition_weights = ctx.Input<Tensor>("Transition");
auto* label = ctx.Input<LoDTensor>("Label"); auto* label = ctx.Input<LoDTensor>("Label");
......
...@@ -51,7 +51,7 @@ class CrossEntropyOp : public framework::OperatorWithKernel { ...@@ -51,7 +51,7 @@ class CrossEntropyOp : public framework::OperatorWithKernel {
protected: protected:
// Explicitly set that the data type of computation kernel of cross_entropy // Explicitly set that the data type of computation kernel of cross_entropy
// is determined by its input "X". // is determined by its input "X".
framework::OpKernelType GetActualKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("X")->type()), framework::ToDataType(ctx.Input<Tensor>("X")->type()),
...@@ -101,7 +101,7 @@ class CrossEntropyGradientOp : public framework::OperatorWithKernel { ...@@ -101,7 +101,7 @@ class CrossEntropyGradientOp : public framework::OperatorWithKernel {
protected: protected:
// Explicitly set that the data type of computation kernel of cross_entropy // Explicitly set that the data type of computation kernel of cross_entropy
// is determined by its input "X". // is determined by its input "X".
framework::OpKernelType GetActualKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("X")->type()), framework::ToDataType(ctx.Input<Tensor>("X")->type()),
......
...@@ -49,7 +49,7 @@ class FillConstantBatchSizeLikeOp : public framework::OperatorWithKernel { ...@@ -49,7 +49,7 @@ class FillConstantBatchSizeLikeOp : public framework::OperatorWithKernel {
} }
protected: protected:
framework::OpKernelType GetActualKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(
static_cast<framework::proto::DataType>(ctx.Attr<int>("dtype")), static_cast<framework::proto::DataType>(ctx.Attr<int>("dtype")),
......
...@@ -40,7 +40,7 @@ class GatherOp : public framework::OperatorWithKernel { ...@@ -40,7 +40,7 @@ class GatherOp : public framework::OperatorWithKernel {
} }
protected: protected:
framework::OpKernelType GetActualKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("X")->type()), framework::ToDataType(ctx.Input<Tensor>("X")->type()),
...@@ -57,7 +57,7 @@ class GatherGradOp : public framework::OperatorWithKernel { ...@@ -57,7 +57,7 @@ class GatherGradOp : public framework::OperatorWithKernel {
} }
protected: protected:
framework::OpKernelType GetActualKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("X")->type()), framework::ToDataType(ctx.Input<Tensor>("X")->type()),
......
...@@ -60,7 +60,7 @@ class GaussianRandomOp : public framework::OperatorWithKernel { ...@@ -60,7 +60,7 @@ class GaussianRandomOp : public framework::OperatorWithKernel {
} }
protected: protected:
framework::OpKernelType GetActualKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(
static_cast<framework::proto::DataType>(ctx.Attr<int>("dtype")), static_cast<framework::proto::DataType>(ctx.Attr<int>("dtype")),
......
...@@ -183,7 +183,7 @@ class LinearChainCRFOp : public framework::OperatorWithKernel { ...@@ -183,7 +183,7 @@ class LinearChainCRFOp : public framework::OperatorWithKernel {
protected: protected:
// Explicitly set that the data type of computation kernel of linear_chain_crf // Explicitly set that the data type of computation kernel of linear_chain_crf
// is determined by its input "Emission". // is determined by its input "Emission".
framework::OpKernelType GetActualKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(
framework::ToDataType(ctx.Input<LoDTensor>("Emission")->type()), framework::ToDataType(ctx.Input<LoDTensor>("Emission")->type()),
...@@ -242,7 +242,7 @@ class LinearChainCRFGradOp : public framework::OperatorWithKernel { ...@@ -242,7 +242,7 @@ class LinearChainCRFGradOp : public framework::OperatorWithKernel {
protected: protected:
// Explicitly set that the data type of output of the linear_chain_crf_grad // Explicitly set that the data type of output of the linear_chain_crf_grad
// operator is determined by its input: gradients of LogLikelihood. // operator is determined by its input: gradients of LogLikelihood.
framework::OpKernelType GetActualKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(
framework::ToDataType( framework::ToDataType(
......
...@@ -38,7 +38,7 @@ class LoDResetOp : public framework::OperatorWithKernel { ...@@ -38,7 +38,7 @@ class LoDResetOp : public framework::OperatorWithKernel {
} }
protected: protected:
framework::OpKernelType GetActualKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()), framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()),
...@@ -97,7 +97,7 @@ class LoDResetGradOp : public framework::OperatorWithKernel { ...@@ -97,7 +97,7 @@ class LoDResetGradOp : public framework::OperatorWithKernel {
} }
protected: protected:
framework::OpKernelType GetActualKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()), framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()),
......
...@@ -99,9 +99,9 @@ class LogicalOp : public framework::OperatorWithKernel { ...@@ -99,9 +99,9 @@ class LogicalOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
protected: protected:
framework::OpKernelType GetActualKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
framework::OpKernelType kt = OperatorWithKernel::GetActualKernelType(ctx); framework::OpKernelType kt = OperatorWithKernel::GetExpectedKernelType(ctx);
// LogicalOp kernel's device type is decided by input tensor place // LogicalOp kernel's device type is decided by input tensor place
kt.place_ = ctx.Input<framework::LoDTensor>("X")->place(); kt.place_ = ctx.Input<framework::LoDTensor>("X")->place();
return kt; return kt;
......
...@@ -41,7 +41,7 @@ class LookupTableOp : public framework::OperatorWithKernel { ...@@ -41,7 +41,7 @@ class LookupTableOp : public framework::OperatorWithKernel {
} }
protected: protected:
framework::OpKernelType GetActualKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(
framework::ToDataType(ctx.Input<LoDTensor>("W")->type()), framework::ToDataType(ctx.Input<LoDTensor>("W")->type()),
...@@ -98,7 +98,7 @@ class LookupTableOpGrad : public framework::OperatorWithKernel { ...@@ -98,7 +98,7 @@ class LookupTableOpGrad : public framework::OperatorWithKernel {
} }
protected: protected:
framework::OpKernelType GetActualKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(
framework::ToDataType(ctx.Input<LoDTensor>("W")->type()), framework::ToDataType(ctx.Input<LoDTensor>("W")->type()),
......
...@@ -92,7 +92,7 @@ class LSTMOp : public framework::OperatorWithKernel { ...@@ -92,7 +92,7 @@ class LSTMOp : public framework::OperatorWithKernel {
} }
protected: protected:
framework::OpKernelType GetActualKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::LoDTensor>("Input")->type()), framework::ToDataType(ctx.Input<framework::LoDTensor>("Input")->type()),
...@@ -260,7 +260,7 @@ class LSTMGradOp : public framework::OperatorWithKernel { ...@@ -260,7 +260,7 @@ class LSTMGradOp : public framework::OperatorWithKernel {
} }
protected: protected:
framework::OpKernelType GetActualKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::LoDTensor>("Input")->type()), framework::ToDataType(ctx.Input<framework::LoDTensor>("Input")->type()),
......
...@@ -51,7 +51,7 @@ class MultiplexOp : public framework::OperatorWithKernel { ...@@ -51,7 +51,7 @@ class MultiplexOp : public framework::OperatorWithKernel {
} }
protected: protected:
framework::OpKernelType GetActualKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(
framework::ToDataType(ctx.MultiInput<Tensor>("X")[0]->type()), framework::ToDataType(ctx.MultiInput<Tensor>("X")[0]->type()),
...@@ -102,7 +102,7 @@ class MultiplexGradOp : public framework::OperatorWithKernel { ...@@ -102,7 +102,7 @@ class MultiplexGradOp : public framework::OperatorWithKernel {
} }
protected: protected:
framework::OpKernelType GetActualKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(
framework::ToDataType(ctx.MultiInput<Tensor>("X")[0]->type()), framework::ToDataType(ctx.MultiInput<Tensor>("X")[0]->type()),
......
...@@ -63,7 +63,7 @@ class NCEOp : public framework::OperatorWithKernel { ...@@ -63,7 +63,7 @@ class NCEOp : public framework::OperatorWithKernel {
} }
protected: protected:
framework::OpKernelType GetActualKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("Input")->type()), framework::ToDataType(ctx.Input<Tensor>("Input")->type()),
...@@ -166,7 +166,7 @@ class NCEOpGrad : public framework::OperatorWithKernel { ...@@ -166,7 +166,7 @@ class NCEOpGrad : public framework::OperatorWithKernel {
} }
protected: protected:
framework::OpKernelType GetActualKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("Input")->type()), framework::ToDataType(ctx.Input<Tensor>("Input")->type()),
......
...@@ -69,7 +69,7 @@ class MaxPoolWithIndexOp : public framework::OperatorWithKernel { ...@@ -69,7 +69,7 @@ class MaxPoolWithIndexOp : public framework::OperatorWithKernel {
} }
protected: protected:
framework::OpKernelType GetActualKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()), framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()),
...@@ -90,7 +90,7 @@ class MaxPoolWithIndexOpGrad : public framework::OperatorWithKernel { ...@@ -90,7 +90,7 @@ class MaxPoolWithIndexOpGrad : public framework::OperatorWithKernel {
} }
protected: protected:
framework::OpKernelType GetActualKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()), framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()),
......
...@@ -85,7 +85,7 @@ class PositiveNegativePairOp : public framework::OperatorWithKernel { ...@@ -85,7 +85,7 @@ class PositiveNegativePairOp : public framework::OperatorWithKernel {
} }
protected: protected:
framework::OpKernelType GetActualKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("Score")->type()), framework::ToDataType(ctx.Input<Tensor>("Score")->type()),
......
...@@ -80,7 +80,7 @@ class PrecisionRecallOp : public framework::OperatorWithKernel { ...@@ -80,7 +80,7 @@ class PrecisionRecallOp : public framework::OperatorWithKernel {
} }
protected: protected:
framework::OpKernelType GetActualKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("MaxProbs")->type()), framework::ToDataType(ctx.Input<Tensor>("MaxProbs")->type()),
......
...@@ -68,7 +68,7 @@ class ROIPoolOp : public framework::OperatorWithKernel { ...@@ -68,7 +68,7 @@ class ROIPoolOp : public framework::OperatorWithKernel {
} }
protected: protected:
framework::OpKernelType GetActualKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()), framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()),
...@@ -89,7 +89,7 @@ class ROIPoolGradOp : public framework::OperatorWithKernel { ...@@ -89,7 +89,7 @@ class ROIPoolGradOp : public framework::OperatorWithKernel {
} }
protected: protected:
framework::OpKernelType GetActualKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()), framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()),
......
...@@ -49,7 +49,7 @@ class ScatterOp : public framework::OperatorWithKernel { ...@@ -49,7 +49,7 @@ class ScatterOp : public framework::OperatorWithKernel {
} }
protected: protected:
framework::OpKernelType GetActualKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("Ref")->type()), framework::ToDataType(ctx.Input<Tensor>("Ref")->type()),
...@@ -68,7 +68,7 @@ class ScatterGradOp : public framework::OperatorWithKernel { ...@@ -68,7 +68,7 @@ class ScatterGradOp : public framework::OperatorWithKernel {
} }
protected: protected:
framework::OpKernelType GetActualKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("Ref")->type()), framework::ToDataType(ctx.Input<Tensor>("Ref")->type()),
......
...@@ -107,7 +107,7 @@ class SequencePoolGradOp : public framework::OperatorWithKernel { ...@@ -107,7 +107,7 @@ class SequencePoolGradOp : public framework::OperatorWithKernel {
} }
protected: protected:
framework::OpKernelType GetActualKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("X")->type()), framework::ToDataType(ctx.Input<Tensor>("X")->type()),
......
...@@ -48,7 +48,7 @@ class SequenceSliceOp : public framework::OperatorWithKernel { ...@@ -48,7 +48,7 @@ class SequenceSliceOp : public framework::OperatorWithKernel {
} }
protected: protected:
framework::OpKernelType GetActualKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()), framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()),
...@@ -69,7 +69,7 @@ class SequenceSliceGradOp : public framework::OperatorWithKernel { ...@@ -69,7 +69,7 @@ class SequenceSliceGradOp : public framework::OperatorWithKernel {
} }
protected: protected:
framework::OpKernelType GetActualKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()), framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()),
......
...@@ -118,7 +118,7 @@ class SoftmaxWithCrossEntropyOp : public framework::OperatorWithKernel { ...@@ -118,7 +118,7 @@ class SoftmaxWithCrossEntropyOp : public framework::OperatorWithKernel {
} }
protected: protected:
framework::OpKernelType GetActualKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("Logits")->type()), framework::ToDataType(ctx.Input<Tensor>("Logits")->type()),
...@@ -159,7 +159,7 @@ class SoftmaxWithCrossEntropyOpGrad : public framework::OperatorWithKernel { ...@@ -159,7 +159,7 @@ class SoftmaxWithCrossEntropyOpGrad : public framework::OperatorWithKernel {
} }
protected: protected:
framework::OpKernelType GetActualKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(
framework::ToDataType( framework::ToDataType(
......
...@@ -53,7 +53,7 @@ class SumOp : public framework::OperatorWithKernel { ...@@ -53,7 +53,7 @@ class SumOp : public framework::OperatorWithKernel {
} }
protected: protected:
framework::OpKernelType GetActualKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
auto x_vars = ctx.MultiInputVar("X"); auto x_vars = ctx.MultiInputVar("X");
if (x_vars[0]->IsType<framework::LoDTensor>()) { if (x_vars[0]->IsType<framework::LoDTensor>()) {
......
...@@ -63,7 +63,7 @@ class UniformRandomOp : public framework::OperatorWithKernel { ...@@ -63,7 +63,7 @@ class UniformRandomOp : public framework::OperatorWithKernel {
} }
protected: protected:
framework::OpKernelType GetActualKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(
static_cast<framework::proto::DataType>(ctx.Attr<int>("dtype")), static_cast<framework::proto::DataType>(ctx.Attr<int>("dtype")),
......
...@@ -71,7 +71,7 @@ int OutputSize(int input_size, int ksize, int padding, int stride) { ...@@ -71,7 +71,7 @@ int OutputSize(int input_size, int ksize, int padding, int stride) {
class UnpoolOp : public framework::OperatorWithKernel { class UnpoolOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetActualKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()), framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()),
...@@ -110,7 +110,7 @@ class UnpoolOp : public framework::OperatorWithKernel { ...@@ -110,7 +110,7 @@ class UnpoolOp : public framework::OperatorWithKernel {
class UnpoolOpGrad : public framework::OperatorWithKernel { class UnpoolOpGrad : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetActualKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()), framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()),
......
...@@ -51,6 +51,18 @@ bool places_are_same_class(const Place &p1, const Place &p2) { ...@@ -51,6 +51,18 @@ bool places_are_same_class(const Place &p1, const Place &p2) {
return p1.which() == p2.which(); return p1.which() == p2.which();
} }
bool is_same_place(const Place &p1, const Place &p2) {
if (places_are_same_class(p1, p2)) {
if (is_cpu_place(p1)) {
return true;
} else {
return boost::get<CUDAPlace>(p1) == boost::get<CUDAPlace>(p2);
}
} else {
return false;
}
}
std::ostream &operator<<(std::ostream &os, const Place &p) { std::ostream &operator<<(std::ostream &os, const Place &p) {
detail::PlacePrinter printer(os); detail::PlacePrinter printer(os);
boost::apply_visitor(printer, p); boost::apply_visitor(printer, p);
......
...@@ -61,6 +61,7 @@ const CPUPlace default_cpu(); ...@@ -61,6 +61,7 @@ const CPUPlace default_cpu();
bool is_gpu_place(const Place &); bool is_gpu_place(const Place &);
bool is_cpu_place(const Place &); bool is_cpu_place(const Place &);
bool places_are_same_class(const Place &, const Place &); bool places_are_same_class(const Place &, const Place &);
bool is_same_place(const Place &, const Place &);
std::ostream &operator<<(std::ostream &, const Place &); std::ostream &operator<<(std::ostream &, const Place &);
......
...@@ -4,6 +4,7 @@ import numpy as np ...@@ -4,6 +4,7 @@ import numpy as np
import paddle.v2 as paddle import paddle.v2 as paddle
import paddle.v2.dataset.conll05 as conll05 import paddle.v2.dataset.conll05 as conll05
import paddle.v2.fluid as fluid import paddle.v2.fluid as fluid
import time
word_dict, verb_dict, label_dict = conll05.get_dict() word_dict, verb_dict, label_dict = conll05.get_dict()
word_dict_len = len(word_dict) word_dict_len = len(word_dict)
...@@ -160,7 +161,8 @@ def main(): ...@@ -160,7 +161,8 @@ def main():
paddle.reader.shuffle( paddle.reader.shuffle(
paddle.dataset.conll05.test(), buf_size=8192), paddle.dataset.conll05.test(), buf_size=8192),
batch_size=BATCH_SIZE) batch_size=BATCH_SIZE)
place = fluid.CPUPlace() #place = fluid.CPUPlace()
place = fluid.CUDAPlace(0)
feeder = fluid.DataFeeder( feeder = fluid.DataFeeder(
feed_list=[ feed_list=[
word, ctx_n2, ctx_n1, ctx_0, ctx_p1, ctx_p2, predicate, mark, target word, ctx_n2, ctx_n1, ctx_0, ctx_p1, ctx_p2, predicate, mark, target
...@@ -174,6 +176,7 @@ def main(): ...@@ -174,6 +176,7 @@ def main():
embedding_param.set( embedding_param.set(
load_parameter(conll05.get_embedding(), word_dict_len, word_dim), place) load_parameter(conll05.get_embedding(), word_dict_len, word_dim), place)
start_time = time.time()
batch_id = 0 batch_id = 0
for pass_id in xrange(PASS_NUM): for pass_id in xrange(PASS_NUM):
chunk_evaluator.reset(exe) chunk_evaluator.reset(exe)
...@@ -191,6 +194,9 @@ def main(): ...@@ -191,6 +194,9 @@ def main():
f1_score) + " pass_precision:" + str( f1_score) + " pass_precision:" + str(
pass_precision) + " pass_recall:" + str(pass_recall) pass_precision) + " pass_recall:" + str(pass_recall)
+ " pass_f1_score:" + str(pass_f1_score)) + " pass_f1_score:" + str(pass_f1_score))
if batch_id != 0:
print("second per batch: " + str((time.time() - start_time)
/ batch_id))
# exit early for CI # exit early for CI
exit(0) exit(0)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册