diff --git a/paddle/fluid/framework/operator.h b/paddle/fluid/framework/operator.h
index 359f58328a86c10896f5a852c3683e60841f1eab..5899a14f503fffe603803bfe56533aa40425a252 100644
--- a/paddle/fluid/framework/operator.h
+++ b/paddle/fluid/framework/operator.h
@@ -463,7 +463,8 @@ class OperatorWithKernel : public OperatorBase {
 
   std::vector<KernelConfig>* GetKernelConfig(const OpKernelType& key) const;
 
- protected:
+  // change this to public so that in dygraph mode we can call it to check if
+  // we need to transform data
   virtual OpKernelType GetKernelTypeForVar(
       const std::string& var_name, const Tensor& tensor,
       const OpKernelType& expected_kernel_type) const;
diff --git a/paddle/fluid/imperative/CMakeLists.txt b/paddle/fluid/imperative/CMakeLists.txt
index 1a1deef963542b1d16978d314d126585af5c07b8..5ba7c32d01fce2c75007ac1026278f4a7689ef55 100644
--- a/paddle/fluid/imperative/CMakeLists.txt
+++ b/paddle/fluid/imperative/CMakeLists.txt
@@ -1,6 +1,6 @@
 cc_library(imperative_flag SRCS flags.cc DEPS gflags)
-cc_library(prepared_operator SRCS prepared_operator.cc DEPS proto_desc operator device_context lod_tensor selected_rows var_type_traits)
+cc_library(prepared_operator SRCS prepared_operator.cc DEPS proto_desc operator device_context lod_tensor selected_rows var_type_traits op_kernel_type data_transform)
 cc_library(layer SRCS layer.cc DEPS prepared_operator math_function imperative_flag variable_helper op_registry)
 cc_library(gradient_accumulator SRCS gradient_accumulator.cc DEPS blas operator lod_tensor selected_rows var_type_traits layer)
 cc_library(tracer SRCS tracer.cc DEPS layer engine)
diff --git a/paddle/fluid/imperative/layer.cc b/paddle/fluid/imperative/layer.cc
index bed49f0d12c2da1fa74db93f12f4e88873481e18..8a5db26d7d7f158c3f436e3ad339dd29b8132735 100644
--- a/paddle/fluid/imperative/layer.cc
+++ b/paddle/fluid/imperative/layer.cc
@@ -264,11 +264,11 @@ void OpBase::Run(const NameVarBaseMap& ins, const NameVarBaseMap& outs) {
   VLOG(3) << "Running Op " << Type();
   VLOG(5) << LayerDebugString(Type(), ins, outs);
   auto runtime_ctx = PrepareRuntimeContext(ins, outs);
-  auto runtime_place = PreparedOp::GetExpectedPlace(place(), ins);
-  auto prepared_op =
-      PreparedOp::Prepare(runtime_ctx, *op_kernel, runtime_place);
+  VLOG(6) << "start preparing op: " << Type();
+  auto prepared_op = PreparedOp::Prepare(runtime_ctx, *op_kernel, place(), ins);
+  VLOG(6) << "finish preparing op: " << Type();
   prepared_op.Run();
 
   VLOG(4) << LayerDebugString(Type(), ins, outs);
 }
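Note: `PreparedOp::Prepare` now takes `place` by value rather than `const platform::Place&` because it may reassign it to the kernel's expected place before preparing input data (see prepared_operator.cc below). A condensed sketch of the resulting dygraph run flow, using only names that appear in this patch; the surrounding scaffolding is illustrative, not verbatim patch code:

```cpp
// Sketch of how OpBase::Run drives an op after this patch (logging omitted).
auto runtime_ctx = PrepareRuntimeContext(ins, outs);
// Prepare() now also receives the inputs, so it can switch to the kernel's
// expected place and transform any input tensor that lives elsewhere.
auto prepared_op = PreparedOp::Prepare(runtime_ctx, *op_kernel, place(), ins);
prepared_op.Run();  // RuntimeInferShape, then the kernel functor
```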
"Input variable should keep in the same place: %s, " - "but get place: %s of input %s instead", - place, tmp_place, name_pair.first); + // TODO(jiabin): Support transform data layout when we Verify it on more + // tests + if (!(tmp_place == place)) { + auto kernel_type_for_var = op.GetKernelTypeForVar( + name_pair.first, *tensor, expected_kernel_key); + if (!NeedTransform(kernel_type_for_var, expected_kernel_key)) { + continue; + } else { + VLOG(3) << "Transform Variable " << var_base->Name() << " from " + << kernel_type_for_var << " to " << expected_kernel_key; + framework::Tensor out; + TransformData(expected_kernel_key, kernel_type_for_var, *tensor, + &out); + SetTensorToVariable(var_base->Var(), out, var_base->MutableVar()); + } + } } } } - return place; } PreparedOp::PreparedOp(const framework::OperatorBase& op, @@ -59,8 +71,10 @@ PreparedOp::PreparedOp(const framework::OperatorBase& op, PreparedOp PreparedOp::Prepare(const framework::RuntimeContext& ctx, const framework::OperatorWithKernel& op, - const platform::Place& place) { - auto* dev_ctx = platform::DeviceContextPool::Instance().Get(place); + platform::Place place, + const NameVarBaseMap& ins) { + platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); + auto* dev_ctx = pool.Get(place); // check if op[type] has kernel registered. auto& all_op_kernels = op.AllOpKernels(); @@ -86,6 +100,13 @@ PreparedOp PreparedOp::Prepare(const framework::RuntimeContext& ctx, } std::vector* kernel_configs = op.GetKernelConfig(expected_kernel_key); + + if (!(expected_kernel_key.place_ == place)) { + dev_ctx = pool.Get(expected_kernel_key.place_); + place = dev_ctx->GetPlace(); + } + + PrepareData(place, ins, op, expected_kernel_key); return PreparedOp(op, ctx, kernel_iter->second, dev_ctx, kernel_configs); } @@ -93,6 +114,7 @@ void PreparedOp::Run() { // TODO(zjl): remove scope in dygraph framework::Scope scope; op_.RuntimeInferShape(scope, dev_ctx_->GetPlace(), ctx_); + VLOG(6) << "Finish Runtime infer shape"; func_(framework::ExecutionContext(op_, scope, *dev_ctx_, ctx_, kernel_configs_)); } diff --git a/paddle/fluid/imperative/prepared_operator.h b/paddle/fluid/imperative/prepared_operator.h index 4616a85674683695875ff932c11dd3adba384170..886311f8c82fce4b3b1cd46bbe2ac6e5f22c50e5 100644 --- a/paddle/fluid/imperative/prepared_operator.h +++ b/paddle/fluid/imperative/prepared_operator.h @@ -17,6 +17,8 @@ #include #include #include +#include "paddle/fluid/framework/data_transform.h" +#include "paddle/fluid/framework/op_kernel_type.h" #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/imperative/layer.h" #include "paddle/fluid/imperative/type_defs.h" @@ -30,14 +32,16 @@ class PreparedOp { public: static PreparedOp Prepare(const framework::RuntimeContext& ctx, const framework::OperatorWithKernel& op, - const platform::Place& place); + platform::Place place, const NameVarBaseMap& ins); inline platform::DeviceContext* GetDeviceContext() const { return dev_ctx_; } void Run(); - static platform::Place GetExpectedPlace(const platform::Place& place, - const NameVarBaseMap& ins); + static void PrepareData(const platform::Place& place, + const NameVarBaseMap& ins, + const framework::OperatorWithKernel& op, + const framework::OpKernelType& expected_kernel_key); private: PreparedOp(const framework::OperatorBase& op, diff --git a/paddle/fluid/imperative/tests/CMakeLists.txt b/paddle/fluid/imperative/tests/CMakeLists.txt index 25a038997fa85b0f181e3bd43dbd79a9cd9a9b25..f32f0a1726fc07bab5fdbb971fa258a97e3c8f7f 100644 --- 
diff --git a/paddle/fluid/imperative/tests/CMakeLists.txt b/paddle/fluid/imperative/tests/CMakeLists.txt
index 25a038997fa85b0f181e3bd43dbd79a9cd9a9b25..f32f0a1726fc07bab5fdbb971fa258a97e3c8f7f 100644
--- a/paddle/fluid/imperative/tests/CMakeLists.txt
+++ b/paddle/fluid/imperative/tests/CMakeLists.txt
@@ -1,5 +1,5 @@
 cc_test(nccl_context_test SRCS nccl_context_test.cc DEPS nccl_context)
 cc_test(test_gradient_accmulator SRCS test_gradient_accmulator.cc DEPS gradient_accumulator memcpy)
-cc_test(test_layer SRCS test_layer.cc DEPS layer proto_desc operator op_registry variable_helper mul_op)
-cc_test(test_prepare_op SRCS test_prepare_op.cc DEPS prepared_operator op_info split_op layer concat_and_split)
-cc_test(test_tracer SRCS test_tracer.cc DEPS tracer layer proto_desc operator op_registry variable_helper mul_op)
+cc_test(test_layer SRCS test_layer.cc DEPS layer proto_desc operator op_registry variable_helper mul_op memcpy)
+cc_test(test_prepare_op SRCS test_prepare_op.cc DEPS prepared_operator op_info split_op layer concat_and_split assign_op place)
+cc_test(test_tracer SRCS test_tracer.cc DEPS tracer layer proto_desc operator op_registry variable_helper mul_op memcpy)
diff --git a/paddle/fluid/imperative/tests/test_prepare_op.cc b/paddle/fluid/imperative/tests/test_prepare_op.cc
index 6c35248262ae43696990ae2b874e58ad81fb3c26..1a30868da041eb0c7dc2d7ed9308871f231f5ab9 100644
--- a/paddle/fluid/imperative/tests/test_prepare_op.cc
+++ b/paddle/fluid/imperative/tests/test_prepare_op.cc
@@ -111,7 +111,7 @@ TEST(test_prepare_op, test_prepare_op) {
                                        split_attr_map);
   framework::RuntimeContext ctx = PrepareRuntimeContext(ins, outs);
   ASSERT_NO_FATAL_FAILURE(PreparedOp preparedOp =
-                              PreparedOp::Prepare(ctx, op, place));
+                              PreparedOp::Prepare(ctx, op, place, ins));
 }
 
 const framework::Tensor* GetTensorFromVar(const framework::Variable& var);
@@ -123,8 +123,94 @@ TEST(test_prepare_op, test_get_tensor_from_var) {
   auto* ts = GetTensorFromVar(*vout_error->MutableVar());
   ASSERT_TRUE(ts != nullptr);
 }
 
+#if defined(PADDLE_WITH_CUDA)
+TEST(test_prepare_op, test_prepare_data) {
+  std::shared_ptr<imperative::VarBase> vin(
+      new imperative::VarBase(false, "vin"));
+  std::shared_ptr<imperative::VarBase> vout(
+      new imperative::VarBase(false, "vout"));
+
+  framework::OpDesc desc;
+  platform::CPUPlace cpu_place;
+  platform::CUDAPlace gpu_place(0);
+  std::vector<float> src_data(10, 2.0);
+  std::vector<int64_t> dims = {2, 5};
+
+  // prepare a CPU-only input
+  auto* vin_tensor = vin->MutableVar()->GetMutable<framework::LoDTensor>();
+  vin_tensor->Resize(framework::make_ddim(dims));
+  auto* vin_mutable_tensor = vin_tensor->mutable_data<float>(cpu_place);
+  paddle::memory::Copy(cpu_place, vin_mutable_tensor, cpu_place,
+                       src_data.data(), sizeof(float) * src_data.size());
+  var_pair x_pair = var_pair("X", vb_vector(1, vin));
+  var_pair out_pair = var_pair("Out", vb_vector(1, vout));
+  imperative::NameVarBaseMap ins = {x_pair};
+  imperative::NameVarBaseMap outs = {out_pair};
+  framework::AttributeMap assign_attr_map;
+  const auto& info = framework::OpInfoMap::Instance().Get("assign");
+  framework::VariableNameMap var_in_map =
+      CreateVarNameMap(info, "assign", ins, true);
+  framework::VariableNameMap var_out_map =
+      CreateVarNameMap(info, "assign", outs, false);
+  framework::OperatorWithKernel assign_op("assign", var_in_map, var_out_map,
+                                          assign_attr_map);
+  framework::RuntimeContext ctx = PrepareRuntimeContext(ins, outs);
+
+  // check that the input can be transformed to the GPU place
+  PreparedOp prepared_op = PreparedOp::Prepare(ctx, assign_op, gpu_place, ins);
+  for (const auto& name_pair : ins) {
+    for (const auto& vb : name_pair.second) {
+      ASSERT_TRUE(platform::is_same_place(
+          vb->Var().Get<framework::LoDTensor>().place(), gpu_place));
+    }
+  }
+}
+#endif
+
+TEST(test_prepare_op, test_prepare_data_same_place) {
+  std::shared_ptr<imperative::VarBase> vin(
+      new imperative::VarBase(false, "vin"));
+  std::shared_ptr<imperative::VarBase> vout(
+      new imperative::VarBase(false, "vout"));
+
+  framework::OpDesc desc;
+  platform::CPUPlace cpu_place;
+  std::vector<float> src_data(10, 2.0);
+  std::vector<int64_t> dims = {2, 5};
+
+  // prepare a CPU-only input
+  auto* vin_tensor = vin->MutableVar()->GetMutable<framework::LoDTensor>();
+  vin_tensor->Resize(framework::make_ddim(dims));
+  auto* vin_mutable_tensor = vin_tensor->mutable_data<float>(cpu_place);
+  paddle::memory::Copy(cpu_place, vin_mutable_tensor, cpu_place,
+                       src_data.data(), sizeof(float) * src_data.size());
+
+  var_pair x_pair = var_pair("X", vb_vector(1, vin));
+  var_pair out_pair = var_pair("Out", vb_vector(1, vout));
+  imperative::NameVarBaseMap ins = {x_pair};
+  imperative::NameVarBaseMap outs = {out_pair};
+  framework::AttributeMap assign_attr_map;
+  const auto& info = framework::OpInfoMap::Instance().Get("assign");
+  framework::VariableNameMap var_in_map =
+      CreateVarNameMap(info, "assign", ins, true);
+  framework::VariableNameMap var_out_map =
+      CreateVarNameMap(info, "assign", outs, false);
+  framework::OperatorWithKernel assign_op("assign", var_in_map, var_out_map,
+                                          assign_attr_map);
+  framework::RuntimeContext ctx = PrepareRuntimeContext(ins, outs);
+
+  // check that no transfer happens when the input already sits on the
+  // requested place
+  PreparedOp prepared_op = PreparedOp::Prepare(ctx, assign_op, cpu_place, ins);
+  for (const auto& name_pair : ins) {
+    for (const auto& vb : name_pair.second) {
+      ASSERT_TRUE(platform::is_same_place(
+          vb->Var().Get<framework::LoDTensor>().place(), cpu_place));
+    }
+  }
+}
 }  // namespace imperative
 }  // namespace paddle
 
 USE_OP(split);
+USE_OP(assign);
diff --git a/paddle/fluid/imperative/tests/test_tracer.cc b/paddle/fluid/imperative/tests/test_tracer.cc
index 3a544e5f502b1e635a5185a9bb9c86b181b8535d..1bd0e8bc9da4fe916baa62e31dde606813f535db 100644
--- a/paddle/fluid/imperative/tests/test_tracer.cc
+++ b/paddle/fluid/imperative/tests/test_tracer.cc
@@ -22,6 +22,7 @@
 #include <vector>
 #include "gtest/gtest.h"
 #include "paddle/fluid/imperative/tracer.h"
+#include "paddle/fluid/memory/memcpy.h"
 
 namespace imperative = paddle::imperative;
 namespace platform = paddle::platform;
@@ -142,6 +143,48 @@ TEST(test_tracer, test_track_backward_input) {
   mul_attr_map["use_mkldnn"] = false;
   ASSERT_ANY_THROW(tracer.TraceOp("mul", ins, outs, mul_attr_map, place, true));
 }
 
+#if defined(PADDLE_WITH_CUDA)
+TEST(test_tracer, test_trace_op_with_multi_device_inputs) {
+  // Do a mul with inputs that live on different devices
+  imperative::Tracer tracer;
+  std::shared_ptr<imperative::VarBase> x_in(
+      new imperative::VarBase(true, "x_in"));
+  std::shared_ptr<imperative::VarBase> y_in(
+      new imperative::VarBase(true, "y_in"));
+  std::shared_ptr<imperative::VarBase> vout(
+      new imperative::VarBase(true, "vout"));
+  platform::CPUPlace place;
+  platform::CUDAPlace gpu_place(0);
+  std::vector<float> src_data(10, 2.0);
+  std::vector<int64_t> dims1 = {2, 5};
+  std::vector<int64_t> dims2 = {5, 2};
+
+  auto* x_in_tensor = x_in->MutableVar()->GetMutable<framework::LoDTensor>();
+  auto* y_in_tensor = y_in->MutableVar()->GetMutable<framework::LoDTensor>();
+  x_in_tensor->Resize(framework::make_ddim(dims1));
+  auto* mutable_x = x_in_tensor->mutable_data<float>(place);
+  paddle::memory::Copy(place, mutable_x, place, src_data.data(),
+                       sizeof(float) * src_data.size());
+  y_in_tensor->Resize(framework::make_ddim(dims2));
+  auto* mutable_y = y_in_tensor->mutable_data<float>(gpu_place);
+  paddle::memory::Copy(gpu_place, mutable_y, place, src_data.data(),
+                       sizeof(float) * src_data.size(), 0);
+  var_pair x_pair = var_pair("X", vb_vector(1, x_in));
+  var_pair y_pair = var_pair("Y", vb_vector(1, y_in));
+  var_pair out_pair = var_pair("Out", vb_vector(1, vout));
+  imperative::NameVarBaseMap ins = {x_pair, y_pair};
+  imperative::NameVarBaseMap outs = {out_pair};
+  framework::AttributeMap mul_attr_map;
+  mul_attr_map["use_mkldnn"] = false;
+  tracer.TraceOp("mul", ins, outs, mul_attr_map, gpu_place, true);
+  framework::LoDTensor rlt;
+  framework::TensorCopySync(vout->Var().Get<framework::LoDTensor>(), place,
+                            &rlt);
+  for (size_t i = 0; i < rlt.numel(); i++) {
+    ASSERT_EQ(rlt.data<float>()[i], 20.0);
+  }
+}
+#endif
 }  // namespace imperative
 }  // namespace paddle
diff --git a/paddle/fluid/imperative/tracer.cc b/paddle/fluid/imperative/tracer.cc
index e3cf2b32d797f812c237ab183fc2818382ecb1d0..ff40f2e8d88a53b70b1895d4aa095dd544d4cbed 100644
--- a/paddle/fluid/imperative/tracer.cc
+++ b/paddle/fluid/imperative/tracer.cc
@@ -45,11 +45,13 @@ void Tracer::TraceOp(const std::string& type, const NameVarBaseMap& ins,
     TraceBackward(op, framework::OpDesc(op->Type(), op->InputNameMap(),
                                         op->OutputNameMap(), op->Attrs()),
                   ins, outs);
+    VLOG(6) << "Finish tracking Backward of op: " << type;
   }
+  VLOG(6) << "Finish tracing fwd op: " << type;
 }
 
 bool Tracer::ComputeRequiredGrad(const NameVarBaseMap& ins,
-                                 const NameVarBaseMap outs,
+                                 const NameVarBaseMap& outs,
                                  bool trace_backward) {
   // TODO(jiabin): Implement auto prune here
   return trace_backward;
diff --git a/paddle/fluid/imperative/tracer.h b/paddle/fluid/imperative/tracer.h
index f0a75d44731b20df36de3f93c1dce5a98ce6ae57..9c24b65ee1603d41cc038c28560358d7c3c27bb0 100644
--- a/paddle/fluid/imperative/tracer.h
+++ b/paddle/fluid/imperative/tracer.h
@@ -40,8 +40,8 @@ class Tracer {
                const NameVarBaseMap& outs, framework::AttributeMap attrs,
                const platform::Place& place, bool trace_bacward);
 
-  bool ComputeRequiredGrad(const NameVarBaseMap& ins, const NameVarBaseMap outs,
-                           bool trace_backward);
+  bool ComputeRequiredGrad(const NameVarBaseMap& ins,
+                           const NameVarBaseMap& outs, bool trace_backward);
 
   void TraceBackward(const std::shared_ptr<OpBase>& fwd_op,
                      const framework::OpDesc& fwd_op_desc,
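Note: the `ComputeRequiredGrad` change is a pass-by-value fix: `const NameVarBaseMap outs` copied the whole map (string keys plus vectors of `shared_ptr`) on every traced op, whereas `const NameVarBaseMap&` gives the same const guarantees without the copy. A minimal self-contained illustration; `VarBase` here is a stand-in struct, not the real class:

```cpp
#include <map>
#include <memory>
#include <string>
#include <vector>

struct VarBase {};  // stand-in for imperative::VarBase
using NameVarBaseMap =
    std::map<std::string, std::vector<std::shared_ptr<VarBase>>>;

// By value: copies every vector and bumps every shared_ptr refcount per call.
bool RequiredGradByValue(const NameVarBaseMap outs) { return !outs.empty(); }

// By const reference: no copy; the callee still cannot mutate the map.
bool RequiredGradByRef(const NameVarBaseMap& outs) { return !outs.empty(); }
```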
diff --git a/python/paddle/fluid/tests/unittests/test_layers.py b/python/paddle/fluid/tests/unittests/test_layers.py
index 4be43ae48073cc7ba3f413ee982654edbb6a7067..47779859a0badb0e26205b17665241a467012dcc 100644
--- a/python/paddle/fluid/tests/unittests/test_layers.py
+++ b/python/paddle/fluid/tests/unittests/test_layers.py
@@ -223,6 +223,13 @@ class TestLayer(LayerTest):
             conv2d = nn.Conv2D('conv2d', num_filters=3, filter_size=[2, 2])
             dy_ret = conv2d(base.to_variable(images))
 
+        with self.dynamic_graph():
+            images = np.ones([2, 3, 5, 5], dtype='float32')
+            conv2d = nn.Conv2D(
+                'conv2d', num_filters=3, filter_size=[2, 2], bias_attr=False)
+            dy_ret = conv2d(base.to_variable(images))
+            self.assertTrue(conv2d._bias_param is None)
+
         self.assertTrue(np.allclose(static_ret, dy_ret.numpy()))
         self.assertTrue(np.allclose(static_ret, static_ret2))
 
@@ -994,6 +1001,105 @@ class TestLayer(LayerTest):
 
         self.assertTrue(np.allclose(static_ret, dy_ret.numpy()))
 
+    def test_compare(self):
+        value_a = np.arange(3)
+        value_b = np.arange(3)
+        # less than
+        with self.static_graph():
+            a = layers.data(name='a', shape=[1], dtype='int64')
+            b = layers.data(name='b', shape=[1], dtype='int64')
+            cond = layers.less_than(x=a, y=b)
+            static_ret = self.get_static_graph_result(
+                feed={"a": value_a,
+                      "b": value_b}, fetch_list=[cond])[0]
+        with self.dynamic_graph():
+            da = base.to_variable(value_a)
+            db = base.to_variable(value_b)
+            dcond = layers.less_than(x=da, y=db)
+
+            for i in range(len(static_ret)):
+                self.assertTrue(dcond.numpy()[i] == static_ret[i])
+
+        # less equal
+        with self.static_graph():
+            a1 = layers.data(name='a1', shape=[1], dtype='int64')
+            b1 = layers.data(name='b1', shape=[1], dtype='int64')
+            cond1 = layers.less_equal(x=a1, y=b1)
+            static_ret1 = self.get_static_graph_result(
+                feed={"a1": value_a,
+                      "b1": value_b}, fetch_list=[cond1])[0]
+        with self.dynamic_graph():
+            da1 = base.to_variable(value_a)
+            db1 = base.to_variable(value_b)
+            dcond1 = layers.less_equal(x=da1, y=db1)
+
+            for i in range(len(static_ret1)):
+                self.assertTrue(dcond1.numpy()[i] == static_ret1[i])
+
+        # greater than
+        with self.static_graph():
+            a2 = layers.data(name='a2', shape=[1], dtype='int64')
+            b2 = layers.data(name='b2', shape=[1], dtype='int64')
+            cond2 = layers.greater_than(x=a2, y=b2)
+            static_ret2 = self.get_static_graph_result(
+                feed={"a2": value_a,
+                      "b2": value_b}, fetch_list=[cond2])[0]
+        with self.dynamic_graph():
+            da2 = base.to_variable(value_a)
+            db2 = base.to_variable(value_b)
+            dcond2 = layers.greater_than(x=da2, y=db2)
+
+            for i in range(len(static_ret2)):
+                self.assertTrue(dcond2.numpy()[i] == static_ret2[i])
+
+        # greater equal
+        with self.static_graph():
+            a3 = layers.data(name='a3', shape=[1], dtype='int64')
+            b3 = layers.data(name='b3', shape=[1], dtype='int64')
+            cond3 = layers.greater_equal(x=a3, y=b3)
+            static_ret3 = self.get_static_graph_result(
+                feed={"a3": value_a,
+                      "b3": value_b}, fetch_list=[cond3])[0]
+        with self.dynamic_graph():
+            da3 = base.to_variable(value_a)
+            db3 = base.to_variable(value_b)
+            dcond3 = layers.greater_equal(x=da3, y=db3)
+
+            for i in range(len(static_ret3)):
+                self.assertTrue(dcond3.numpy()[i] == static_ret3[i])
+
+        # equal
+        with self.static_graph():
+            a4 = layers.data(name='a4', shape=[1], dtype='int64')
+            b4 = layers.data(name='b4', shape=[1], dtype='int64')
+            cond4 = layers.equal(x=a4, y=b4)
+            static_ret4 = self.get_static_graph_result(
+                feed={"a4": value_a,
+                      "b4": value_b}, fetch_list=[cond4])[0]
+        with self.dynamic_graph():
+            da4 = base.to_variable(value_a)
+            db4 = base.to_variable(value_b)
+            dcond4 = layers.equal(x=da4, y=db4)
+
+            for i in range(len(static_ret4)):
+                self.assertTrue(dcond4.numpy()[i] == static_ret4[i])
+
+        # not equal
+        with self.static_graph():
+            a5 = layers.data(name='a5', shape=[1], dtype='int64')
+            b5 = layers.data(name='b5', shape=[1], dtype='int64')
+            cond5 = layers.not_equal(x=a5, y=b5)
+            static_ret5 = self.get_static_graph_result(
+                feed={"a5": value_a,
+                      "b5": value_b}, fetch_list=[cond5])[0]
+        with self.dynamic_graph():
+            da5 = base.to_variable(value_a)
+            db5 = base.to_variable(value_b)
+            dcond5 = layers.not_equal(x=da5, y=db5)
+
+            for i in range(len(static_ret5)):
+                self.assertTrue(dcond5.numpy()[i] == static_ret5[i])
+
 
 class TestBook(LayerTest):
     def test_all_layers(self):