// Copyright (c) Microsoft. All Rights Reserved. Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using System;
using System.Runtime.InteropServices;
using System.Threading;

namespace Microsoft.CodeAnalysis.Test.Utilities
{
    /// <summary>
    /// This factory creates COM "blind aggregator" instances in managed code.
    /// </summary>
    /// <remarks>
    /// An aggregator answers <c>IUnknown</c> and the wrapper interface (CBD71F2C-...) itself,
    /// refuses <c>IMarshal</c>, and forwards every other <c>QueryInterface</c> request to the
    /// inner object attached through <see cref="SetInnerObject"/>.
    /// </remarks>
    public static class BlindAggregatorFactory
    {
        /// <summary>
        /// Allocates a new aggregator in CoTaskMem-allocated native memory. It starts with a
        /// reference count of 1 and no inner object.
        /// </summary>
        /// <returns>The aggregator's interface pointer (its address in native memory).</returns>
        public static unsafe IntPtr CreateWrapper()
        {
            return (IntPtr)BlindAggregator.CreateInstance();
        }

        /// <summary>
        /// Attaches the inner object that unrecognized interface requests are forwarded to.
        /// </summary>
        /// <param name="wrapperUnknown">A pointer previously returned by <see cref="CreateWrapper"/>.</param>
        /// <param name="innerUnknown">
        /// The inner object's IUnknown pointer. The aggregator takes its own reference (AddRef)
        /// and releases it on final release or when a different inner object replaces it.
        /// </param>
        /// <param name="managedObjectGCHandlePtr">
        /// A <see cref="GCHandle"/> converted to <see cref="IntPtr"/>, or <see cref="IntPtr.Zero"/>.
        /// The aggregator frees it on final release, so ownership passes to the aggregator.
        /// </param>
        /// <exception cref="ArgumentNullException">
        /// <paramref name="wrapperUnknown"/> or <paramref name="innerUnknown"/> is <see cref="IntPtr.Zero"/>.
        /// </exception>
        public static unsafe void SetInnerObject(IntPtr wrapperUnknown, IntPtr innerUnknown, IntPtr managedObjectGCHandlePtr)
        {
            if (wrapperUnknown == IntPtr.Zero)
            {
                throw new ArgumentNullException(nameof(wrapperUnknown));
            }

            BlindAggregator* pWrapper = (BlindAggregator*)wrapperUnknown;
            pWrapper->SetInnerObject(innerUnknown, managedObjectGCHandlePtr);
        }

        /// <summary>
        /// A blind aggregator instance. It is allocated in native memory.
        /// </summary>
        /// <remarks>
        /// The vtable pointer must stay the first field. The address of this struct is handed out
        /// as the COM interface pointer, and callers find the vtable at offset 0.
        /// </remarks>
        [StructLayout(LayoutKind.Sequential)]
        private struct BlindAggregator
        {
            private IntPtr _vfPtr; // Pointer to the virtual function table
            private int _refCount; // COM reference count
            private IntPtr _innerUnknown; // CCW for the managed object supporting aggregation
            private IntPtr _gcHandle; // The GC Handle to the managed object (the non aggregated object)

            /// <summary>
            /// Allocates and initializes an aggregator with <see cref="Marshal.AllocCoTaskMem"/>.
            /// The memory is freed by <see cref="Release"/> when the count reaches zero.
            /// </summary>
            public static unsafe BlindAggregator* CreateInstance()
            {
                BlindAggregator* pResult = (BlindAggregator*)Marshal.AllocCoTaskMem(sizeof(BlindAggregator));
                if (pResult != null)
                {
                    pResult->Construct();
                }

                return pResult;
            }

            private void Construct()
            {
                _vfPtr = VTable.AddressOfVTable;
                _refCount = 1;
                _innerUnknown = IntPtr.Zero;
                _gcHandle = IntPtr.Zero;
            }

            /// <summary>
            /// Stores the inner object (taking a reference on it) and the GC handle.
            /// Any previously attached inner object and handle are released rather than leaked.
            /// </summary>
            public void SetInnerObject(IntPtr innerUnknown, IntPtr gcHandle)
            {
                if (innerUnknown == IntPtr.Zero)
                {
                    throw new ArgumentNullException(nameof(innerUnknown));
                }

                // AddRef the new object before releasing the old one, so that re-attaching the
                // same pointer never drops its count to zero in between.
                Marshal.AddRef(innerUnknown);
                IntPtr previousInner = _innerUnknown;
                _innerUnknown = innerUnknown;
                if (previousInner != IntPtr.Zero)
                {
                    Marshal.Release(previousInner);
                }

                // The aggregator owns the handle (FinalRelease frees it), so free a replaced one.
                if (_gcHandle != IntPtr.Zero && _gcHandle != gcHandle)
                {
                    GCHandle.FromIntPtr(_gcHandle).Free();
                }

                _gcHandle = gcHandle;
            }

            /// <summary>
            /// Drops the inner object reference and frees the GC handle, if either was set.
            /// </summary>
            private void FinalRelease()
            {
                // A wrapper may be released without SetInnerObject ever having been called.
                // Marshal.Release rejects a null pointer, and this runs inside a native-callable
                // vtable slot, so guard it.
                if (_innerUnknown != IntPtr.Zero)
                {
                    Marshal.Release(_innerUnknown);
                    _innerUnknown = IntPtr.Zero;
                }

                if (_gcHandle != IntPtr.Zero)
                {
                    GCHandle.FromIntPtr(_gcHandle).Free();
                    _gcHandle = IntPtr.Zero;
                }
            }

            private unsafe delegate int QueryInterfaceDelegateType(BlindAggregator* pThis, [In] ref Guid riid, out IntPtr pvObject);
            private unsafe delegate uint AddRefDelegateType(BlindAggregator* pThis);
            private unsafe delegate uint ReleaseDelegateType(BlindAggregator* pThis);
            private unsafe delegate int GetGCHandlePtrDelegateType(BlindAggregator* pThis, out IntPtr pResult);

            /// <summary>
            /// Native layout of the aggregator's vtable. Slots follow IUnknown order
            /// (QueryInterface, AddRef, Release), followed by the wrapper's GetGCHandlePtr.
            /// </summary>
            [StructLayout(LayoutKind.Sequential)]
            private struct VTable
            {
                // Need these to keep the delegates alive. The native function pointers created
                // from them are only valid while the delegates are reachable.
                private static unsafe readonly QueryInterfaceDelegateType s_queryInterface = BlindAggregator.QueryInterface;
                private static unsafe readonly AddRefDelegateType s_addRef = BlindAggregator.AddRef;
                private static unsafe readonly ReleaseDelegateType s_release = BlindAggregator.Release;
                private static unsafe readonly GetGCHandlePtrDelegateType s_get_GCHandlePtr = BlindAggregator.GetGCHandlePtr;

                private IntPtr _queryInterfacePtr;
                private IntPtr _addRefPtr;
                private IntPtr _releasePtr;
                private IntPtr _getGCHandlePtr;

                private void Construct()
                {
                    _queryInterfacePtr = Marshal.GetFunctionPointerForDelegate(VTable.s_queryInterface);
                    _addRefPtr = Marshal.GetFunctionPointerForDelegate(VTable.s_addRef);
                    _releasePtr = Marshal.GetFunctionPointerForDelegate(VTable.s_release);
                    _getGCHandlePtr = Marshal.GetFunctionPointerForDelegate(VTable.s_get_GCHandlePtr);
                }

                /// <summary>
                /// A 'holder' for a native memory allocation. The allocation is freed in the finalizer.
                /// </summary>
                private class CoTaskMemPtr
                {
                    public readonly IntPtr VTablePtr;

                    public unsafe CoTaskMemPtr()
                    {
                        var ptr = Marshal.AllocCoTaskMem(sizeof(VTable));
                        this.VTablePtr = ptr;
                        ((VTable*)ptr)->Construct();
                    }

                    ~CoTaskMemPtr()
                    {
                        Marshal.FreeCoTaskMem(this.VTablePtr);
                    }
                }

                // Singleton instance of the VTable allocated in native memory. Since it's static, the
                // underlying native memory will be freed when finalizers run at shutdown.
                // Declared after the delegate fields so that they are initialized first.
                private static readonly CoTaskMemPtr s_instance = new CoTaskMemPtr();

                public static IntPtr AddressOfVTable
                {
                    get
                    {
                        return s_instance.VTablePtr;
                    }
                }
            }

            private const int S_OK = 0;
            private const int E_NOINTERFACE = unchecked((int)0x80004002);

            // 00000000-0000-0000-C000-000000000046
            private static readonly Guid s_IUnknownInterfaceGuid = new Guid(0x00000000, 0x0000, 0x0000, 0xC0, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x46);

            // 00000003-0000-0000-C000-000000000046
            private static readonly Guid s_IMarshalInterfaceGuid = new Guid(0x00000003, 0x0000, 0x0000, 0xC0, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x46);

            // CBD71F2C-6BC5-4932-B851-B93EB3151386
            private static readonly Guid s_IComWrapperGuid = new Guid("CBD71F2C-6BC5-4932-B851-B93EB3151386");

            /// <summary>
            /// Vtable slot 0. Answers IUnknown and the wrapper interface with the aggregator itself
            /// (after AddRef), refuses IMarshal, and forwards every other IID to the inner object.
            /// </summary>
            private static unsafe int QueryInterface(BlindAggregator* pThis, [In] ref Guid riid, out IntPtr pvObject)
            {
                if (riid == s_IUnknownInterfaceGuid || riid == s_IComWrapperGuid)
                {
                    AddRef(pThis);
                    pvObject = (IntPtr)pThis;
                    return S_OK;
                }

                // With no inner object attached there is nothing to forward to. Marshal.QueryInterface
                // would throw on a null pointer, and that exception would have to cross the native
                // caller's frame.
                if (riid == s_IMarshalInterfaceGuid || pThis->_innerUnknown == IntPtr.Zero)
                {
                    pvObject = IntPtr.Zero;
                    return E_NOINTERFACE;
                }

                // We don't know what the interface is, so aggregate blindly from here
                return Marshal.QueryInterface(pThis->_innerUnknown, ref riid, out pvObject);
            }

            /// <summary>
            /// Vtable slot 1. Atomically increments the reference count and returns the new value.
            /// </summary>
            private static unsafe uint AddRef(BlindAggregator* pThis)
            {
                return unchecked((uint)Interlocked.Increment(ref pThis->_refCount));
            }

            /// <summary>
            /// Vtable slot 2. Atomically decrements the reference count. On reaching zero it releases
            /// the inner state and frees the aggregator's native memory.
            /// </summary>
            private static unsafe uint Release(BlindAggregator* pThis)
            {
                uint result = unchecked((uint)Interlocked.Decrement(ref pThis->_refCount));
                if (result == 0u)
                {
                    pThis->FinalRelease();
                    Marshal.FreeCoTaskMem((IntPtr)pThis);
                }

                return result;
            }

            /// <summary>
            /// Vtable slot 3. Returns the stored GC handle pointer (possibly <see cref="IntPtr.Zero"/>).
            /// </summary>
            private static unsafe int GetGCHandlePtr(BlindAggregator* pThis, out IntPtr pResult)
            {
                pResult = pThis->_gcHandle;
                return S_OK;
            }
        }
    }
}