diff options
author | Rob Austein <sra@hactrn.net> | 2018-03-21 09:08:30 +0000 |
---|---|---|
committer | Rob Austein <sra@hactrn.net> | 2018-03-21 09:08:30 +0000 |
commit | 9a956ed5a42301ee1efb5642cc0f381751d917f5 (patch) | |
tree | 3f19b930af198ed1889d15b9b574e7ba46aa931c | |
parent | 894181009ad3002d84d2ce6ea74bbd5aea068999 (diff) |
Supply our own context manager instead of using contextlib.
contextlib is cute, but incompatible with other coroutine schemes like
Tornado, so just write our own context manager for xdrlib.Unpacker.
-rw-r--r-- | cryptech/libhal.py | 18 |
1 files changed, 12 insertions, 6 deletions
diff --git a/cryptech/libhal.py b/cryptech/libhal.py index acd1abb..273a8a0 100644 --- a/cryptech/libhal.py +++ b/cryptech/libhal.py @@ -43,7 +43,6 @@ import uuid import xdrlib import socket import logging -import contextlib logger = logging.getLogger(__name__) @@ -408,6 +407,15 @@ class PKey(Handle): def import_pkey(self, pkcs8, kek, flags = 0): return self.hsm.pkey_import(kekek = self, pkcs8 = pkcs8, kek = kek, flags = flags) +class ContextManagedUnpacker(xdrlib.Unpacker): + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.done() + + class HSM(object): mixed_mode = False @@ -429,7 +437,7 @@ class HSM(object): logger.debug("send: %s", ":".join("{:02x}".format(ord(c)) for c in msg)) self.socket.sendall(msg) - def _recv(self, code): # Returns an xdrlib.Unpacker + def _recv(self, code): # Returns a ContextManagedUnpacker closed = False while True: msg = [self.sockfile.read(1)] @@ -442,7 +450,7 @@ class HSM(object): msg = slip_decode("".join(msg)) if not msg: continue - msg = xdrlib.Unpacker("".join(msg)) + msg = ContextManagedUnpacker("".join(msg)) if msg.unpack_uint() != code: continue return msg @@ -480,7 +488,6 @@ class HSM(object): self._pack_arg(packer, name) self._pack_arg(packer, HAL_PKEY_ATTRIBUTE_NIL if value is None else value) - @contextlib.contextmanager def rpc(self, code, *args, **kwargs): client = kwargs.get("client", 0) packer = xdrlib.Packer() @@ -491,8 +498,7 @@ class HSM(object): unpacker = self._recv(code) client = unpacker.unpack_uint() self._raise_if_error(unpacker.unpack_uint()) - yield unpacker - unpacker.done() + return unpacker def get_version(self): with self.rpc(RPC_FUNC_GET_VERSION) as r: |