// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #pragma once #include #include #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/platform/enforce.h" namespace paddle { namespace operators { using Tensor = framework::Tensor; template typename std::enable_if::value>::type CopyVecotorToTensor( const char* value_name, framework::Tensor* out, const framework::ExecutionContext& ctx) { // If attribute value dtype is vector, it will be converted to // vector. // at the same time, we can not use vector to hold the value, because // the c++ use bit value to replace byte value. auto values = ctx.Attr>(value_name); framework::TensorFromVector(values, ctx.device_context(), out); // use the array to replace to vector bool* array_ptr = new T[values.size()]; for (unsigned int i = 0; i < values.size(); i++) { array_ptr[i] = static_cast(values[i]); } framework::TensorFromArray(array_ptr, values.size(), ctx.device_context(), out); delete[] array_ptr; } template typename std::enable_if::value>::type CopyVecotorToTensor(const char* value_name, framework::Tensor* out, const framework::ExecutionContext& ctx) { auto values = ctx.Attr>(value_name); framework::TensorFromVector(values, ctx.device_context(), out); } template class AssignValueKernel : public framework::OpKernel { public: virtual void Compute(const framework::ExecutionContext& ctx) const { auto shape = ctx.Attr>("shape"); auto* out = ctx.Output("Out"); int dtype = ctx.Attr("dtype"); const char* value_name = nullptr; switch (dtype) { case framework::proto::VarType::BOOL: value_name = "bool_values"; break; case framework::proto::VarType::INT32: value_name = "int32_values"; break; case framework::proto::VarType::FP32: value_name = "fp32_values"; break; case framework::proto::VarType::INT64: value_name = "int64_values"; break; default: PADDLE_THROW("Unsupported dtype for assign_value_op: %d", dtype); break; } CopyVecotorToTensor(value_name, out, ctx); out->Resize(framework::make_ddim(shape)); } }; } // namespace operators } // namespace paddle