aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRob Austein <sra@hactrn.net>2018-03-21 09:08:30 +0000
committerRob Austein <sra@hactrn.net>2018-03-21 09:08:30 +0000
commit9a956ed5a42301ee1efb5642cc0f381751d917f5 (patch)
tree3f19b930af198ed1889d15b9b574e7ba46aa931c
parent894181009ad3002d84d2ce6ea74bbd5aea068999 (diff)
Supply our own context manager instead of using contextlib.
contextlib is cute, but incompatible with other coroutine schemes like Tornado, so just write our own context manager for xdrlib.Unpacker.
-rw-r--r--cryptech/libhal.py18
1 files changed, 12 insertions, 6 deletions
diff --git a/cryptech/libhal.py b/cryptech/libhal.py
index acd1abb..273a8a0 100644
--- a/cryptech/libhal.py
+++ b/cryptech/libhal.py
@@ -43,7 +43,6 @@ import uuid
import xdrlib
import socket
import logging
-import contextlib
logger = logging.getLogger(__name__)
@@ -408,6 +407,15 @@ class PKey(Handle):
def import_pkey(self, pkcs8, kek, flags = 0):
return self.hsm.pkey_import(kekek = self, pkcs8 = pkcs8, kek = kek, flags = flags)
+class ContextManagedUnpacker(xdrlib.Unpacker):
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ self.done()
+
+
class HSM(object):
mixed_mode = False
@@ -429,7 +437,7 @@ class HSM(object):
logger.debug("send: %s", ":".join("{:02x}".format(ord(c)) for c in msg))
self.socket.sendall(msg)
- def _recv(self, code): # Returns an xdrlib.Unpacker
+ def _recv(self, code): # Returns a ContextManagedUnpacker
closed = False
while True:
msg = [self.sockfile.read(1)]
@@ -442,7 +450,7 @@ class HSM(object):
msg = slip_decode("".join(msg))
if not msg:
continue
- msg = xdrlib.Unpacker("".join(msg))
+ msg = ContextManagedUnpacker("".join(msg))
if msg.unpack_uint() != code:
continue
return msg
@@ -480,7 +488,6 @@ class HSM(object):
self._pack_arg(packer, name)
self._pack_arg(packer, HAL_PKEY_ATTRIBUTE_NIL if value is None else value)
- @contextlib.contextmanager
def rpc(self, code, *args, **kwargs):
client = kwargs.get("client", 0)
packer = xdrlib.Packer()
@@ -491,8 +498,7 @@ class HSM(object):
unpacker = self._recv(code)
client = unpacker.unpack_uint()
self._raise_if_error(unpacker.unpack_uint())
- yield unpacker
- unpacker.done()
+ return unpacker
def get_version(self):
with self.rpc(RPC_FUNC_GET_VERSION) as r: