引用计数智能指针的正确实现

Correct implementation of ref counting smart pointer

简介:

我们正在尝试实现我们自己的引用计数智能指针(类似于 std::shared_ptr)。我们已经有了其他 classes 可以继承的引用计数接口 class。该接口提供 grab()drop() 方法,它们相应地递增和递减引用计数。 Tldr 手动参考计数。我们现在要做的是类似 RAII 的包装器,它在复制(复制构造函数)时调用 grab() 并在析构函数中调用 drop()

代码

这里 source code 介绍了介绍部分中描述的智能指针的实现。必要的部分也贴在下面:

    //IReferenceCounted is the interface class defining grab() and drop() methods incrementing and decrementing ref count

    template<class I_REFERENCE_COUNTED>
    class smart_refctd_ptr
    {
            static_assert(std::is_base_of<IReferenceCounted, I_REFERENCE_COUNTED>::value,"Wrong Base Class!");
            
            mutable I_REFERENCE_COUNTED* ptr; // since IReferenceCounted declares the refcount mutable atomic
            template<class U> friend class smart_refctd_ptr;
        public:
            constexpr smart_refctd_ptr() noexcept : ptr(nullptr) {}
            constexpr smart_refctd_ptr(std::nullptr_t) noexcept : ptr(nullptr) {}
            template<class U>
            explicit smart_refctd_ptr(U* _pointer) noexcept : ptr(_pointer)
            {
                if (_pointer)
                    _pointer->grab();
            }
            template<class U>
            explicit smart_refctd_ptr(U* _pointer, dont_grab_t t) noexcept : ptr(_pointer) {}
            template<class U>
            smart_refctd_ptr(const smart_refctd_ptr<U>& other) noexcept : smart_refctd_ptr(other.ptr) {}
            template<class U>
            smart_refctd_ptr(smart_refctd_ptr<U>&& other) noexcept : smart_refctd_ptr()
            {
                if (ptr) // should only happen if constexpr (is convertible)
                    ptr->drop();
                ptr = other.ptr;
                other.ptr = nullptr; // should only happen if constexpr (is convertible)
            }
            ~smart_refctd_ptr() noexcept
            {
                if (ptr)
                    ptr->drop();
            }

            template<class U>
            inline smart_refctd_ptr& operator=(U* _pointer) noexcept
            {
                if (_pointer)
                    _pointer->grab();
                if (ptr)
                    ptr->drop();
                ptr = _pointer;
                return *this;
            }
            template<class U>
            inline smart_refctd_ptr& operator=(const smart_refctd_ptr<U>& other) noexcept
            {
                return operator=(other.ptr);
            }
            template<class U>
            inline smart_refctd_ptr& operator=(smart_refctd_ptr<U>&& other) noexcept
            {
                if (ptr) // should only happen if constexpr (is convertible)
                    ptr->drop();
                ptr = other.ptr;
                other.ptr = nullptr; // should only happen if constexpr (is convertible)
                return *this;
            }

            inline I_REFERENCE_COUNTED* get() { return ptr; }
            inline const I_REFERENCE_COUNTED* get() const { return ptr; }

            inline I_REFERENCE_COUNTED* operator->() { return ptr; }
            inline const I_REFERENCE_COUNTED* operator->() const { return ptr; }

            inline I_REFERENCE_COUNTED& operator*() { return *ptr; }
            inline const I_REFERENCE_COUNTED& operator*() const { return *ptr; }

            inline I_REFERENCE_COUNTED& operator[](size_t idx) { return ptr[idx]; }
            inline const I_REFERENCE_COUNTED& operator[](size_t idx) const { return ptr[idx]; }


            inline explicit operator bool() const { return ptr; }
            inline bool operator!() const { return !ptr; }

            template<class U>
            inline bool operator==(const smart_refctd_ptr<U> &other) const { return ptr == other.ptr; }
            template<class U>
            inline bool operator!=(const smart_refctd_ptr<U> &other) const { return ptr != other.ptr; }

            template<class U>
            inline bool operator<(const smart_refctd_ptr<U> &other) const { return ptr < other.ptr; }
            template<class U>
            inline bool operator>(const smart_refctd_ptr<U>& other) const { return ptr > other.ptr; }
    };

问题

事实证明,它并不像看起来那么容易,因为我认为是复制省略。复制时不调用我们的智能指针的复制构造函数。我在 Visual Studio 2017 年工作,似乎甚至没有从复制构造函数 C++ 代码生成程序集:我无法在其中放置断点,也看不到为其生成的任何 asm。 即使是这两行简单的 C++ 代码...

    core::smart_refctd_ptr<IAsset> mesh_ptr(mesh);
    core::smart_refctd_ptr<IAsset> mesh_ptr2 = mesh_ptr;

...我得到这个 asm:

;    core::smart_refctd_ptr<IAsset> mesh_ptr(mesh);
 mov         rdx,qword ptr [mesh]  
 lea         rcx,[mesh_ptr]  
 call        irr::core::smart_refctd_ptr<irr::asset::IAsset>::smart_refctd_ptr<irr::asset::IAsset><irr::asset::SCPUMesh> (013F2A79B7h)  
 nop  
 ;   core::smart_refctd_ptr<IAsset> mesh_ptr2 = mesh_ptr;
 mov         rax,qword ptr [mesh_ptr]  
 mov         qword ptr [mesh_ptr2],rax  

根本不调用复制构造函数。因此 ref count 没有增加,但是调用了 2 个析构函数,这显然导致减少的数量多于增加的数量。这发生在 使用 /Od 标志 的调试构建中(甚至还没有尝试完全优化)。 因为它应该与 std::shared_ptr 非常相似,所以我在我的编译器(如上所述的 VS2017)中查找了它的实现。但是,我真的看不到任何可以告诉我解决方案的技巧。我试过 - std::shared_ptr 工作正常(引用计数在副本上正确增加),但我不知道为什么。为什么 shared_ptr 有效而我们的 smart_refctd_ptr 无效?

解决方案是创建非模板复制构造函数。用户定义的复制构造函数必须是非模板构造函数,如 12.8.2 [class.copy] 中的 C++14 标准状态。同样的事情需要用移动构造函数、复制赋值运算符和移动赋值运算符来完成。所以复制构造函数现在看起来像这样:

            template<class U>
            void copy(const smart_refctd_ptr<U>& other) noexcept
            {
                if (other.ptr)
                    other.ptr->grab();
                ptr = other.ptr;
            }

            template<class U, std::enable_if_t<!std::is_same<U,I_REFERENCE_COUNTED>::value, int> = 0>
            smart_refctd_ptr(const smart_refctd_ptr<U>& other) noexcept
            {
                this->copy(other);
            }
            smart_refctd_ptr(const smart_refctd_ptr<I_REFERENCE_COUNTED>& other) noexcept
            {
                this->copy(other);
            }

请注意,很可能不需要 std::enable_if_t<...,int> = 0 模板参数。