From 9a956ed5a42301ee1efb5642cc0f381751d917f5 Mon Sep 17 00:00:00 2001 From: Rob Austein Date: Wed, 21 Mar 2018 09:08:30 +0000 Subject: Supply our own context manager instead of using contextlib. contextlib is cute, but incompatible with other coroutine schemes like Tornado, so just write our own context manager for xdrlib.Unpacker. --- cryptech/libhal.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/cryptech/libhal.py b/cryptech/libhal.py index acd1abb..273a8a0 100644 --- a/cryptech/libhal.py +++ b/cryptech/libhal.py @@ -43,7 +43,6 @@ import uuid import xdrlib import socket import logging -import contextlib logger = logging.getLogger(__name__) @@ -408,6 +407,15 @@ class PKey(Handle): def import_pkey(self, pkcs8, kek, flags = 0): return self.hsm.pkey_import(kekek = self, pkcs8 = pkcs8, kek = kek, flags = flags) +class ContextManagedUnpacker(xdrlib.Unpacker): + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.done() + + class HSM(object): mixed_mode = False @@ -429,7 +437,7 @@ class HSM(object): logger.debug("send: %s", ":".join("{:02x}".format(ord(c)) for c in msg)) self.socket.sendall(msg) - def _recv(self, code): # Returns an xdrlib.Unpacker + def _recv(self, code): # Returns a ContextManagedUnpacker closed = False while True: msg = [self.sockfile.read(1)] @@ -442,7 +450,7 @@ class HSM(object): msg = slip_decode("".join(msg)) if not msg: continue - msg = xdrlib.Unpacker("".join(msg)) + msg = ContextManagedUnpacker("".join(msg)) if msg.unpack_uint() != code: continue return msg @@ -480,7 +488,6 @@ class HSM(object): self._pack_arg(packer, name) self._pack_arg(packer, HAL_PKEY_ATTRIBUTE_NIL if value is None else value) - @contextlib.contextmanager def rpc(self, code, *args, **kwargs): client = kwargs.get("client", 0) packer = xdrlib.Packer() @@ -491,8 +498,7 @@ class HSM(object): unpacker = self._recv(code) client = unpacker.unpack_uint() self._raise_if_error(unpacker.unpack_uint()) - yield unpacker - unpacker.done() + return unpacker def get_version(self): with self.rpc(RPC_FUNC_GET_VERSION) as r: -- cgit v1.2.3