Making Unsafe Rust Safe

One of Rust's greatest strengths is its guarantee that if safe code compiles, it's free of data races and memory-safety violations. However, all bets are off once you use unsafe code, for example to integrate with a C library. To use unsafe and keep that guarantee, you need to make especially sure that your unsafe code is wrapped in safe code that precisely defines the contracts under which it can be used.

A Completely Unsafe API

Let's say we want to create safe bindings for a library that has no thread-safety guarantees whatsoever:

/// The "sys" module is a hypothetical C API or "sys" crate that is
/// intended to be used from only one thread.
mod sys {
    pub unsafe fn foo() {
        // maybe this implementation modifies some sort of global state
    }
    pub unsafe fn bar() {
        // maybe this implementation modifies some sort of global state
    }
}

To make this safe, we'll need a type that wraps the API. To make sure this type can't be instantiated more than once at a time, even from different threads, it will have to be a singleton:

mod api {
    use std::sync::atomic::{AtomicBool, Ordering};
    use super::sys;

    static API_EXISTS: AtomicBool = AtomicBool::new(false);
    
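    // The private field stops code outside this module from constructing an
    // API directly, so get() is the only way to obtain one.
    pub struct API(());
    
    impl API {
        pub fn get() -> Option<API> {
            // Atomically flip API_EXISTS from false to true. Only one caller can
            // succeed until the existing API is dropped.
            match API_EXISTS.compare_exchange(false, true, Ordering::Acquire, Ordering::Relaxed) {
                Ok(_) => Some(API(())),
                Err(_) => None,
            }
        }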
        
        pub fn foo(&mut self) {
            unsafe {
                sys::foo()
            }
        }

        pub fn bar(&mut self) {
            unsafe {
                sys::bar()
            }
        }
    }
    
    impl Drop for API {
        fn drop(&mut self) {
            API_EXISTS.store(false, Ordering::Release)
        }
    }
}

We can then use it like so:

fn main() {
    use api::API;

    {
        // Getting the API the first time should work.
        let mut api = API::get().unwrap();
        api.foo();
        api.bar();
        
        // Getting the API a second time should fail.
        assert!(API::get().is_none());
    }
    
    // Getting the API after it was dropped should work again.
    let mut api = API::get().unwrap();
    api.foo();
}

It's not possible to have more than one instance of API at a time, and calling its methods requires a mutable reference to it. Together, these guarantee that the API can't be used simultaneously from multiple threads. If needed, we can wrap it in an Arc<Mutex<API>> to allow safe, synchronized access from multiple threads.
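
As a rough sketch of that last point, the following self-contained program shares the singleton between threads through Arc<Mutex<API>>. The sys module here is a hypothetical stand-in that only counts calls so the example can run; CALLS and run_shared are illustrative names, not part of any real library.

```rust
use std::sync::atomic::Ordering;
use std::sync::{Arc, Mutex};
use std::thread;

/// Hypothetical stand-in for the single-threaded C API: it only counts calls.
mod sys {
    use std::sync::atomic::{AtomicUsize, Ordering};

    pub static CALLS: AtomicUsize = AtomicUsize::new(0);

    pub unsafe fn foo() {
        CALLS.fetch_add(1, Ordering::SeqCst);
    }
}

/// The singleton wrapper, trimmed down to foo().
mod api {
    use super::sys;
    use std::sync::atomic::{AtomicBool, Ordering};

    static API_EXISTS: AtomicBool = AtomicBool::new(false);

    pub struct API(());

    impl API {
        pub fn get() -> Option<API> {
            API_EXISTS
                .compare_exchange(false, true, Ordering::Acquire, Ordering::Relaxed)
                .ok()
                .map(|_| API(()))
        }

        pub fn foo(&mut self) {
            unsafe { sys::foo() }
        }
    }

    impl Drop for API {
        fn drop(&mut self) {
            API_EXISTS.store(false, Ordering::Release);
        }
    }
}

/// Calls foo() `per_thread` times from each of `threads` threads through one
/// shared Arc<Mutex<API>>; returns how many sys::foo() calls were made.
fn run_shared(threads: usize, per_thread: usize) -> usize {
    let before = sys::CALLS.load(Ordering::SeqCst);
    let shared = Arc::new(Mutex::new(api::API::get().expect("API already in use")));
    let handles: Vec<_> = (0..threads)
        .map(|_| {
            let worker_api = Arc::clone(&shared);
            thread::spawn(move || {
                for _ in 0..per_thread {
                    // The MutexGuard provides the &mut API that foo() requires.
                    worker_api.lock().unwrap().foo();
                }
            })
        })
        .collect();
    for handle in handles {
        handle.join().unwrap();
    }
    sys::CALLS.load(Ordering::SeqCst) - before
}

fn main() {
    println!("sys::foo was called {} times", run_shared(4, 10)); // prints "sys::foo was called 40 times"
}
```

The Mutex supplies exactly the exclusive access that the &mut self methods already demand, so the singleton's guarantee carries over to the multi-threaded case. Once every Arc is dropped, the API is dropped too and API::get() succeeds again.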

Initialization

Many C libraries require an initialization function to be called before use:

/// The "sys" module is a hypothetical C API or "sys" crate that is
/// thread-safe, but requires initialization and cleanup.
mod sys {
    pub unsafe fn init() {}
    pub unsafe fn foo() {
        // maybe this implementation reads state created by init
    }
    pub unsafe fn cleanup() {}
}

In this case, it's okay for multiple threads to invoke foo simultaneously, but to make sure init gets invoked first, we still need a wrapper type for the API:

mod api {
    use std::sync::{Arc, Mutex, Weak};
    use lazy_static::lazy_static;
    use super::sys;

    struct APIImpl;

    lazy_static! {
        static ref GLOBAL_API_IMPL: Mutex<Option<Weak<APIImpl>>> = Mutex::new(None);
    }
    
    impl Drop for APIImpl {
        fn drop(&mut self) {
            unsafe {
                sys::cleanup()
            }
        }
    }

    pub struct API(Arc<APIImpl>);
    
    impl API {
        pub fn new() -> API {
            let mut api = GLOBAL_API_IMPL.lock().unwrap();
            let existing = (*api).as_ref().and_then(|api| api.upgrade());
            match existing {
                Some(api) => API(api.clone()),
                None => {
                    unsafe {
                        sys::init();
                    }
                    let new_api = Arc::new(APIImpl);
                    *api = Some(Arc::downgrade(&new_api));
                    API(new_api)
                }
            }
        }
        
        pub fn foo(&self) {
            unsafe {
                sys::foo()
            }
        }
    }
}

With this implementation, each API owns a strong reference (an Arc) to a shared APIImpl. If no APIImpl currently exists, the first API creates it and invokes init. When the last API is dropped, the APIImpl is dropped too, and its Drop implementation invokes cleanup. This works because the global holds only a weak reference to the APIImpl: it lets new find and share an existing APIImpl without keeping it alive. The global itself is created with the lazy_static crate.
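
The mechanism hinges on how Weak behaves: holding a Weak doesn't keep the value alive, so once the last Arc is dropped, the value's Drop runs and upgrade() returns None. The std-only sketch below isolates that behavior; Resource is a hypothetical stand-in for APIImpl whose Drop takes the place of sys::cleanup.

```rust
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, Weak};

// Counts how many times the stand-in "cleanup" has run.
static CLEANUPS: AtomicUsize = AtomicUsize::new(0);

/// Hypothetical stand-in for APIImpl: its Drop plays the role of sys::cleanup.
struct Resource;

impl Drop for Resource {
    fn drop(&mut self) {
        CLEANUPS.fetch_add(1, Ordering::SeqCst);
    }
}

/// Returns (upgrade worked while an Arc existed,
///          upgrade worked after the last Arc was dropped,
///          cleanups that ran during this call).
fn weak_demo() -> (bool, bool, usize) {
    let before = CLEANUPS.load(Ordering::SeqCst);
    let strong = Arc::new(Resource);
    let weak: Weak<Resource> = Arc::downgrade(&strong);
    // The temporary Arc from upgrade() is dropped at the end of this statement.
    let upgraded_while_alive = weak.upgrade().is_some();
    drop(strong); // last strong reference: Resource::drop runs here
    let upgraded_after_drop = weak.upgrade().is_some();
    (upgraded_while_alive, upgraded_after_drop, CLEANUPS.load(Ordering::SeqCst) - before)
}

fn main() {
    println!("{:?}", weak_demo()); // prints "(true, false, 1)"
}
```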

Users of this API no longer even need to be aware of init or cleanup's existence:

fn main() {
    let api = api::API::new();
    api.foo();
}
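
One caveat with the Weak-based version: sys::cleanup runs in APIImpl::drop, which starts only after the strong count has already reached zero, and it runs without GLOBAL_API_IMPL locked. In that window, a concurrent API::new can fail to upgrade the weak reference and call sys::init while cleanup is still running (or before it has started). If the library can't tolerate overlapping init and cleanup, one alternative is to keep a plain count of live handles inside the mutex, so that "first handle calls init" and "last handle calls cleanup" are both decided under the same lock. The sketch below assumes Rust 1.63 or later, where Mutex::new can initialize a static directly (so lazy_static isn't needed); the sys functions are hypothetical stand-ins that count calls, and USERS and lifecycle_demo are illustrative names.

```rust
/// Hypothetical stand-in for the C library: it just counts init/cleanup calls.
mod sys {
    use std::sync::atomic::{AtomicUsize, Ordering};

    pub static INITS: AtomicUsize = AtomicUsize::new(0);
    pub static CLEANUPS: AtomicUsize = AtomicUsize::new(0);

    pub unsafe fn init() {
        INITS.fetch_add(1, Ordering::SeqCst);
    }
    pub unsafe fn foo() {}
    pub unsafe fn cleanup() {
        CLEANUPS.fetch_add(1, Ordering::SeqCst);
    }
}

mod api {
    use super::sys;
    use std::sync::{Mutex, MutexGuard};

    // Number of live API handles. Both the 0 -> 1 transition (init) and the
    // 1 -> 0 transition (cleanup) happen while this lock is held.
    static USERS: Mutex<usize> = Mutex::new(0);

    fn lock_users() -> MutexGuard<'static, usize> {
        // Keep working even if another thread panicked while holding the lock.
        USERS.lock().unwrap_or_else(|poisoned| poisoned.into_inner())
    }

    pub struct API(());

    impl API {
        pub fn new() -> API {
            let mut users = lock_users();
            if *users == 0 {
                unsafe { sys::init() }
            }
            *users += 1;
            API(())
        }

        pub fn foo(&self) {
            unsafe { sys::foo() }
        }
    }

    impl Drop for API {
        fn drop(&mut self) {
            let mut users = lock_users();
            *users -= 1;
            if *users == 0 {
                unsafe { sys::cleanup() }
            }
        }
    }
}

/// Returns how many (inits, cleanups) happen for two overlapping handles
/// followed by a third, later handle.
fn lifecycle_demo() -> (usize, usize) {
    use std::sync::atomic::Ordering::SeqCst;
    let (inits, cleanups) = (sys::INITS.load(SeqCst), sys::CLEANUPS.load(SeqCst));
    {
        let a = api::API::new(); // first handle: init runs
        let b = api::API::new(); // second handle: no init
        a.foo();
        b.foo();
    } // both dropped here; cleanup runs once, for the last one
    drop(api::API::new()); // init and cleanup run again
    (sys::INITS.load(SeqCst) - inits, sys::CLEANUPS.load(SeqCst) - cleanups)
}

fn main() {
    println!("{:?}", lifecycle_demo()); // prints "(2, 2)"
}
```

The trade-off is that init and cleanup now run while the lock is held, so a slow init blocks other threads that are creating or dropping handles; that serialization is exactly what rules out the overlap.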

Callbacks

Let's say our unsafe C API allows us to register a global callback with some opaque pointer that gets passed back to the callback any time it's invoked:

/// The "sys" module is a hypothetical C API or "sys" crate that is
/// not thread-safe.
mod sys {
    pub use libc::c_void as void;

    static mut CB: Option<(unsafe extern "C" fn(*const void), *const void)> = None;

    pub unsafe fn set_callback(
        cb: Option<unsafe extern "C" fn(*const void)>,
        userdata: *const void,
    ) {
        CB = cb.map(|f| (f, userdata));
    }

    pub unsafe fn do_thing() {
        if let Some((f, userdata)) = CB {
            f(userdata)
        }
    }
}
For demonstration purposes, these functions have actual implementations.

We'll be passing a function pointer to set_callback. Typically this is a pointer to a static function that uses userdata to dispatch to the real callback. The trick is making sure the callback lives at least as long as the window of time in which it may be called. One way to do this is to give ownership of the callback to a struct that users invoke your API through:

mod api {
    use super::sys;
    use std::{pin::Pin, sync::atomic::{AtomicBool, Ordering}};

    static API_EXISTS: AtomicBool = AtomicBool::new(false);

    pub struct API {
        callback: Option<Pin<Box<Box<dyn FnMut()>>>>,
    }

    // The trampoline handed to C. It lives at module level (not inside
    // set_callback) so that with_callback and WithCallback can use it too.
    unsafe extern "C" fn callback_impl(f: *const sys::void) {
        (*(f as *mut sys::void as *mut Box<dyn FnMut()>))()
    }

    impl API {
        pub fn get() -> Option<API> {
            match API_EXISTS.compare_exchange(false, true, Ordering::Acquire, Ordering::Relaxed) {
                Ok(_) => Some(API {
                    callback: None,
                }),
                Err(_) => None,
            }
        }

        pub fn set_callback<F: FnMut() + 'static>(&mut self, f: F) {
            let mut cb: Pin<Box<Box<dyn FnMut()>>> = Box::pin(Box::new(f));
            unsafe {
                sys::set_callback(
                    Some(callback_impl),
                    &mut *cb as *mut Box<dyn FnMut()> as *const _,
                )
            }
            // Storing the box keeps the closure alive while C may still call it.
            self.callback = Some(cb);
        }

        pub fn do_thing(&mut self) {
            unsafe { sys::do_thing() }
        }
    }

    impl Drop for API {
        fn drop(&mut self) {
            unsafe {
                sys::set_callback(None, std::ptr::null());
            }
            API_EXISTS.store(false, Ordering::Release)
        }
    }
}

Usage looks like this:

fn main() {
    let mut api = api::API::get().unwrap();
    api.set_callback(|| {
        println!("hello from my callback!");
    });
    api.do_thing();
}
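
In set_callback, the closure is boxed twice because callback_impl is a single, non-generic function: Box<dyn FnMut()> is a fat pointer, so it is boxed again to get a thin pointer that fits in userdata. A common alternative is a trampoline that is generic over the closure type, so each closure type gets its own extern "C" function and one Box<F> is enough. The sketch below assumes a simplified, hypothetical sys module (std::ffi::c_void instead of libc, and a thread-local slot instead of static mut); Registered, trampoline, and trampoline_demo are illustrative names, not part of the API above.

```rust
use std::cell::Cell;
use std::ffi::c_void;
use std::rc::Rc;

/// Hypothetical stand-in for the C callback API (thread-local slot instead of `static mut`).
mod sys {
    use std::cell::Cell;
    use std::ffi::c_void;

    pub type Callback = unsafe extern "C" fn(*const c_void);

    thread_local! {
        // The registered callback and its userdata, like the C library's global.
        static CB: Cell<Option<(Callback, *const c_void)>> = const { Cell::new(None) };
    }

    pub unsafe fn set_callback(cb: Option<Callback>, userdata: *const c_void) {
        CB.with(|slot| slot.set(cb.map(|f| (f, userdata))));
    }

    pub unsafe fn do_thing() {
        if let Some((f, userdata)) = CB.with(|slot| slot.get()) {
            f(userdata)
        }
    }
}

/// Generic trampoline: one copy is generated per closure type `F`, so
/// `userdata` can point straight at an `F` instead of at a boxed trait object.
unsafe extern "C" fn trampoline<F: FnMut()>(userdata: *const c_void) {
    let f = userdata as *mut F;
    (*f)();
}

/// Hypothetical RAII registration that owns the closure through a raw pointer.
struct Registered<F: FnMut()> {
    closure: *mut F,
}

impl<F: FnMut() + 'static> Registered<F> {
    // 'static: even if a Registered is leaked with mem::forget, the closure it
    // leaves registered can't reference anything that goes out of scope.
    fn register(f: F) -> Self {
        // Box::into_raw gives a stable heap address that stays valid until from_raw.
        let closure = Box::into_raw(Box::new(f));
        unsafe { sys::set_callback(Some(trampoline::<F>), closure as *const c_void) };
        Registered { closure }
    }
}

impl<F: FnMut()> Drop for Registered<F> {
    fn drop(&mut self) {
        unsafe {
            // Unregister first so C can no longer reach the closure, then free it.
            sys::set_callback(None, std::ptr::null());
            drop(Box::from_raw(self.closure));
        }
    }
}

/// Registers a counting closure, triggers it `times` times, unregisters it,
/// triggers once more, and returns how many times the closure actually ran.
fn trampoline_demo(times: usize) -> usize {
    let hits = Rc::new(Cell::new(0));
    let hits_in_cb = Rc::clone(&hits);
    let registration = Registered::register(move || hits_in_cb.set(hits_in_cb.get() + 1));
    for _ in 0..times {
        unsafe { sys::do_thing() };
    }
    drop(registration); // unregisters, so the next do_thing() is a no-op
    unsafe { sys::do_thing() };
    hits.get()
}

fn main() {
    println!("callback ran {} times", trampoline_demo(3)); // prints "callback ran 3 times"
}
```

The cost is one small extern "C" function per closure type rather than one shared dispatcher, which is usually negligible.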

But there's one big problem: 'static. The callback given to set_callback obviously has to outlive the API. As a first pass, we required that the callback be 'static, but what if we want to reference local variables from within main?

We can keep our 'static set_callback method, since it has its uses, but to let a callback access local or other non-static variables, we can add another method:

pub fn with_callback<'a, F: FnMut() + 'a>(&'a mut self, f: F) -> WithCallback<'a> {
    let mut cb: Pin<Box<Box<dyn FnMut() + 'a>>> = Box::pin(Box::new(f));
    unsafe {
        sys::set_callback(
            Some(callback_impl),
            &mut *cb as *mut Box<dyn FnMut() + 'a> as *const _,
        )
    }
    WithCallback{
        api: self,
        _callback: cb,
    }
}

This method will set the callback, then return an object that owns the callback and holds a mutable reference to the API. When the object is dropped, the callback will be cleared or restored to the previous callback if one was set with set_callback:

pub struct WithCallback<'a> {
    api: &'a mut API,
    _callback: Pin<Box<Box<dyn FnMut() + 'a>>>,
}

impl<'a> WithCallback<'a> {
    pub fn do_thing(&mut self) {
        self.api.do_thing()
    }
}

impl<'a> Drop for WithCallback<'a> {
    fn drop(&mut self) {
        unsafe {
            match &mut self.api.callback {
                Some(cb) => sys::set_callback(
                    Some(callback_impl),
                    &mut **cb as *mut Box<dyn FnMut()> as *const _,
                ),
                None => sys::set_callback(None, std::ptr::null()),
            }
        }
    }
}

It can be used like so:

fn main() {
    let mut api = api::API::get().unwrap();
    api.set_callback(|| {
        println!("hello from my callback!");
    });
    api.do_thing();

    let mut counter = 1;
    api.with_callback(|| {
        println!("hello from my non-static callback!");
        counter += 1;
    }).do_thing();
}
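
One caveat about WithCallback: std::mem::forget is a safe function, so nothing guarantees that its Drop runs. For example, `std::mem::forget(api.with_callback(|| counter += 1));` ends the borrow of api without unregistering the closure, and a later api.do_thing() would invoke a closure whose borrow of counter has already ended. This is the same hazard that led Rust to remove its original guard-returning thread::scoped API before 1.0. A closure-scoped method sidesteps it, because the guard that unregisters the callback never leaves the method. The sketch below makes several assumptions: sys is a simplified, hypothetical stand-in (thread-local slot, std::ffi::c_void instead of libc), with_scoped_callback and scoped_demo are illustrative names, and for brevity it omits the singleton check and clears the callback instead of restoring one set with set_callback.

```rust
/// Hypothetical stand-in for the C callback API (thread-local slot instead of `static mut`).
mod sys {
    use std::cell::Cell;
    use std::ffi::c_void;

    pub type Callback = unsafe extern "C" fn(*const c_void);

    thread_local! {
        static CB: Cell<Option<(Callback, *const c_void)>> = const { Cell::new(None) };
    }

    pub unsafe fn set_callback(cb: Option<Callback>, userdata: *const c_void) {
        CB.with(|slot| slot.set(cb.map(|f| (f, userdata))));
    }

    pub unsafe fn do_thing() {
        if let Some((f, userdata)) = CB.with(|slot| slot.get()) {
            f(userdata)
        }
    }
}

mod api {
    use super::sys;
    use std::ffi::c_void;

    // Same dispatch idea as callback_impl: userdata points at a Box<dyn FnMut()>.
    unsafe extern "C" fn callback_impl(userdata: *const c_void) {
        let f = userdata as *mut Box<dyn FnMut()>;
        (*f)();
    }

    // Unregisters the C-side callback when dropped, including during unwinding.
    struct ClearOnDrop;

    impl Drop for ClearOnDrop {
        fn drop(&mut self) {
            unsafe { sys::set_callback(None, std::ptr::null()) }
        }
    }

    pub struct API(());

    impl API {
        // Hypothetical constructor; the singleton check is omitted for brevity.
        pub fn new() -> API {
            API(())
        }

        /// Registers `cb` only while `body` runs. Callers never hold a guard,
        /// so there is nothing for them to leak with mem::forget.
        pub fn with_scoped_callback<'a, F, R>(&mut self, cb: F, body: impl FnOnce(&mut API) -> R) -> R
        where
            F: FnMut() + 'a,
        {
            let inner: Box<dyn FnMut() + 'a> = Box::new(cb);
            let mut boxed = Box::new(inner);
            let userdata = &mut *boxed as *mut Box<dyn FnMut() + 'a> as *const c_void;
            unsafe { sys::set_callback(Some(callback_impl), userdata) };
            // Declared after `boxed`, so it is dropped first: the callback is
            // unregistered before the closure is freed, even if `body` panics.
            let _clear = ClearOnDrop;
            body(self)
        }

        pub fn do_thing(&mut self) {
            unsafe { sys::do_thing() }
        }
    }
}

/// Fires the scoped callback `times` times, then calls do_thing() once more
/// after the scope; returns the counter before and after that last call.
fn scoped_demo(times: usize) -> (u32, u32) {
    let mut api = api::API::new();
    let mut counter = 0;
    api.with_scoped_callback(|| counter += 1, |api| {
        for _ in 0..times {
            api.do_thing();
        }
    });
    let inside = counter;
    api.do_thing(); // callback already cleared, so counter is unchanged
    (inside, counter)
}

fn main() {
    println!("{:?}", scoped_demo(3)); // prints "(3, 3)"
}
```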

Sendable Objects

Many libraries have object-oriented APIs that involve creating objects, performing operations on them, then deleting them:

/// The "sys" module is a hypothetical C API or "sys" crate that is
/// thread-safe.
mod sys {
    pub type Object = libc::c_void;
    
    /// Creates a new object, which must be deleted via delete_object.
    pub unsafe fn new_object() -> *mut Object {
        unimplemented!()
    }
    
    pub unsafe fn object_foo(_obj: *mut Object) {
        unimplemented!()
    }

    pub unsafe fn delete_object(_obj: *mut Object) {
        unimplemented!()
    }
}

The safe wrapper for this object is straightforward enough:

mod api {
    use super::sys;

    pub struct Object {
        inner: *mut sys::Object,
    }

    impl Object {
        pub fn new() -> Self {
            Self{
                inner: unsafe {
                    sys::new_object()
                }
            }
        }
        
        pub fn foo(&mut self) {
            unsafe {
                sys::object_foo(self.inner)
            }
        }
    }

    impl Drop for Object {
        fn drop(&mut self) {
            unsafe {
                sys::delete_object(self.inner);
            }
        }
    }
}

This can be used like so:

fn main() {
    let mut obj = api::Object::new();
    obj.foo();
}
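
The wrapper above stores whatever pointer new_object returns. If the real library can report allocation failure by returning null (an assumption here; the hypothetical sys docs above don't say either way), storing a NonNull and returning an Option makes that failure explicit instead of handing later calls a null pointer. In this sketch, the sys functions are hypothetical stand-ins backed by a Box so the code runs, and try_new and foo_calls are illustrative names.

```rust
use std::ptr::NonNull;

/// Hypothetical stand-in for the C object API: objects are heap-allocated counters.
mod sys {
    pub struct Object {
        pub foo_calls: u32,
    }

    pub unsafe fn new_object() -> *mut Object {
        Box::into_raw(Box::new(Object { foo_calls: 0 }))
    }

    pub unsafe fn object_foo(obj: *mut Object) {
        (*obj).foo_calls += 1;
    }

    pub unsafe fn delete_object(obj: *mut Object) {
        drop(Box::from_raw(obj));
    }
}

pub struct Object {
    // NonNull records (and enforces) that a live Object never holds a null pointer.
    inner: NonNull<sys::Object>,
}

impl Object {
    /// Returns None if the library hands back a null pointer.
    pub fn try_new() -> Option<Self> {
        NonNull::new(unsafe { sys::new_object() }).map(|inner| Object { inner })
    }

    pub fn foo(&mut self) {
        unsafe { sys::object_foo(self.inner.as_ptr()) }
    }

    /// Demo helper that reads the mock's counter (not part of any real API).
    pub fn foo_calls(&self) -> u32 {
        unsafe { self.inner.as_ref().foo_calls }
    }
}

impl Drop for Object {
    fn drop(&mut self) {
        unsafe { sys::delete_object(self.inner.as_ptr()) }
    }
}

fn main() {
    let mut obj = Object::try_new().expect("new_object returned null");
    obj.foo();
    obj.foo();
    println!("foo called {} times", obj.foo_calls()); // prints "foo called 2 times"
}
```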

Unfortunately, if we try to use this in an asynchronous context, we're probably going to run into errors. For example, let's pretend obj.foo() takes a long time and we need to spawn it on another thread using tokio:

#[tokio::main]
async fn main() {
    let mut obj = api::Object::new();
    tokio::task::spawn_blocking(move || {
        obj.foo();
    }).await.unwrap();
}

If we try to do this, we'll get a compile error because Object isn't Send: raw pointers (*mut T and *const T) are neither Send nor Sync, so a struct containing one isn't either. In most cases, if the unsafe code doesn't depend on running on any particular thread, it's appropriate to implement Send:

unsafe impl Send for Object {}

This simply tells Rust that it's okay to use the object on threads other than the one it was created on. Note that this does not require the underlying implementation to be thread-safe; it just can't use thread-specific features like thread-local storage.

Now the above code works:

#[tokio::main]
async fn main() {
    let mut obj = api::Object::new();
    tokio::task::spawn_blocking(move || {
        obj.foo();
    }).await.unwrap();
}

Often Sync can be implemented as well, but doing so allows multiple threads to call &self methods simultaneously, which may or may not be okay depending on how the wrapper and the underlying library are implemented.
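
Here is a minimal sketch of a Send-only wrapper using std::thread instead of tokio, so it runs without extra dependencies. The sys module is a hypothetical stand-in that heap-allocates a counter, foo_calls is an illustrative helper, and assert_send is a common compile-time trick rather than a standard-library function. Because Object is Send but not Sync, sharing it between threads still requires something like a Mutex, which provides the exclusive access that foo(&mut self) needs.

```rust
use std::sync::{Arc, Mutex};
use std::thread;

/// Hypothetical stand-in for the thread-safe C object API.
mod sys {
    pub struct Object {
        pub foo_calls: u32,
    }
    pub unsafe fn new_object() -> *mut Object {
        Box::into_raw(Box::new(Object { foo_calls: 0 }))
    }
    pub unsafe fn object_foo(obj: *mut Object) {
        (*obj).foo_calls += 1;
    }
    pub unsafe fn delete_object(obj: *mut Object) {
        drop(Box::from_raw(obj));
    }
}

pub struct Object {
    inner: *mut sys::Object,
}

// Safety (assumed): the library doesn't care which thread an object is used on.
// Sync is deliberately not implemented.
unsafe impl Send for Object {}

impl Object {
    pub fn new() -> Self {
        Object { inner: unsafe { sys::new_object() } }
    }
    pub fn foo(&mut self) {
        unsafe { sys::object_foo(self.inner) }
    }
    /// Demo helper that reads the mock's counter (hypothetical).
    pub fn foo_calls(&self) -> u32 {
        unsafe { (*self.inner).foo_calls }
    }
}

impl Drop for Object {
    fn drop(&mut self) {
        unsafe { sys::delete_object(self.inner) }
    }
}

// Compile-time check: this only compiles if T is Send.
fn assert_send<T: Send>() {}

/// Moves one object to a worker thread, then shares another across threads via a Mutex.
fn send_demo() -> (u32, u32) {
    assert_send::<Object>();

    // Moving ownership to another thread only needs Send.
    let mut obj = Object::new();
    let moved = thread::spawn(move || {
        obj.foo();
        obj.foo_calls()
    })
    .join()
    .unwrap();

    // Sharing needs a Sync wrapper; Mutex<Object> is Sync because Object is Send.
    let shared = Arc::new(Mutex::new(Object::new()));
    let handles: Vec<_> = (0..4)
        .map(|_| {
            let shared = Arc::clone(&shared);
            thread::spawn(move || {
                shared.lock().unwrap().foo();
            })
        })
        .collect();
    for handle in handles {
        handle.join().unwrap();
    }
    let total = shared.lock().unwrap().foo_calls();
    (moved, total)
}

fn main() {
    println!("{:?}", send_demo()); // prints "(1, 4)"
}
```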

