# coding=utf-8
import asyncio
import hashlib
import itertools
import os
import time
import ssl
import urllib
import urllib.parse
from ssl import CertificateError
import urwid
import websockets
import websockets.handshake
from collections import OrderedDict, deque
from websockets.protocol import State
from wsstat.gui import BlinkBoardWidget, LoggerWidget
import logging
import sys
if sys.version_info < (3, 4, 4):
asyncio.ensure_future = asyncio.async
[docs]class ConnectedWebsocketConnection(object):
def __init__(self, ws, identifier):
self.ws = ws
self.id = identifier
self._message_count = itertools.count()
self.last_message_recv = 0
self.started = time.time()
@property
def message_count(self):
return int(repr(self._message_count)[6:-1])
[docs] def increment_message_counter(self):
next(self._message_count)
def __repr__(self):
return "<Websocket {}>".format(self.id)
[docs] def process_message(self, message):
self.increment_message_counter()
self.last_message_recv = time.time()
[docs]class WebsocketTestingClient(object):
"""
Setting up the websocket calls the following callbacks that can be overridden to extend functinality.
For an example see WebsocketTestingClientWithApiTokenHeader
def before_connect(self):
def setup_websocket_connection(self, statedict):
def get_identifier(self, statedict):
def after_connect(self, statedict):
def before_recv(self, statedict):
def after_recv(self, statedict, message):
"""
def __init__(self, websocket_url, **kwargs):
# Configuration stuff
self.frame = None
self.websocket_url = urllib.parse.urlparse(websocket_url)
self.total_connections = kwargs.get('total_connections', 250)
self._exiting = False
self.extra_headers = None
# Asyncio stuff
self.loop = asyncio.get_event_loop()
self.loop.set_exception_handler(self.handle_exceptions)
self.connection_semaphore = asyncio.Semaphore(kwargs.get('max_connecting_sockets', 15))
# Counts and buffers
self.global_message_counter = itertools.count()
self.socket_count = itertools.count(1)
self.sockets = OrderedDict()
self.ring_buffer = deque(maxlen=10)
if kwargs.get('header'):
self.extra_headers = dict([map(lambda x: x.strip(), kwargs['header'].split(':'))])
if kwargs.get('setup_tasks', True):
self.setup_tasks()
self.insecure_connection = kwargs.get('insecure', False)
self.blinkboard = BlinkBoardWidget()
self.logger = LoggerWidget()
self.default_view = urwid.Pile([
self.blinkboard.default_widget,
(10, self.logger.default_widget)
])
self.logger_view = urwid.Pile([
self.logger.logger_widget,
])
self.graph_view = urwid.Pile([
self.logger.graph_widget,
])
self.small_blink_and_graph_view = urwid.Pile([
self.logger.graph_widget,
(10, urwid.LineBox(self.blinkboard.small_blinks)),
])
@property
def messages_per_second(self):
return self._get_current_messages_per_second()
[docs] def log(self, identifier, message):
self.logger.log("[{}] {}".format(identifier, message))
[docs] @asyncio.coroutine
def create_websocket_connection(self):
statedict = self.before_connect()
connection_args = self.setup_websocket_connection(statedict)
# Make len(connection_semaphore) connection attempts at a time
with (yield from self.connection_semaphore):
identifier = self.get_identifier(statedict)
self.log(identifier, 'Connecting to {}'.format(connection_args['uri']))
start_time = time.time()
# Signify that this socket is connecting
self.sockets[identifier] = None
retries = 0
while True:
try:
# Await the connection to complete successfully
websocket = yield from websockets.connect(**connection_args)
websocket.connection_time = time.time() - start_time
break
except BaseException as e:
retries += 1
if isinstance(e, CertificateError):
# If there was an ssl error, bail immediately
self.logger.log("[{}] SSL connection problem! {}".format(identifier, e))
return False
else:
self.logger.log("[{}] {}".format(identifier, e))
if retries > 3:
self.sockets[identifier] = False
if isinstance(e, websockets.InvalidHandshake):
self.sockets[identifier] = e
return False
yield from asyncio.sleep(.25, loop=self.loop)
# Create our handler object
connected_websocket = ConnectedWebsocketConnection(websocket, identifier)
statedict['connected_websocket'] = connected_websocket
# Update the connected_sockets table
self.sockets[identifier] = connected_websocket
# Log that we connected successfully
self.logger.log("[{}] Connected in {:.4f} ms".format(connected_websocket.id, websocket.connection_time * 1000.00))
self.after_connect(statedict)
try:
# Just loop and recv messages
while True:
if self._exiting:
yield from websocket.close()
return True
self.before_recv(statedict)
# Wait for a new message
message = yield from websocket.recv()
self.after_recv(statedict, message)
# Increment our counters
next(self.global_message_counter)
connected_websocket.process_message(message)
except Exception as e:
# Log the exception
self.logger.log("[{}] {}".format(connected_websocket.id, e))
return False
[docs] @asyncio.coroutine
def update_urwid(self):
interval = .1
status_line = "{hostname} | Connections: [{current}/{total}] | Total Messages: {message_count} | Messages/Second: {msgs_per_second}/s"
while True:
if self._exiting:
return True
#raise urwid.ExitMainLoop
# Only update things a max of 10 times/second
yield from asyncio.sleep(interval)
# Get the current global message count
global_message_count = int(repr(self.global_message_counter)[6:-1])
self.ring_buffer.append(global_message_count)
currently_connected_sockets = len([x for x in self.sockets.values() if x and not isinstance(x, BaseException) and x.ws.state == State.OPEN])
self.logger.update_graph_data([self.messages_per_second,])
# Get and update our blinkboard widget
self.blinkboard.generate_blinkers(self.sockets)
# Make the status message
status_message = status_line.format(
hostname=self.websocket_url.netloc,
current=currently_connected_sockets,
total=self.total_connections,
message_count=global_message_count,
msgs_per_second=self.messages_per_second
)
self.frame.footer.set_text(status_message)
[docs] def setup_tasks(self):
tasks = []
for _ in range(self.total_connections):
coro = self.create_websocket_connection()
tasks.append(asyncio.ensure_future(coro))
update_urwid_coro = self.update_urwid()
tasks.append(asyncio.ensure_future(update_urwid_coro))
# Gather all the tasks needed
self.coros = tasks
[docs] def exit(self):
self._exiting = True
import sys
sys.exit(0)
[docs] def handle_keypresses(self, keypress):
if keypress == "q" or keypress == 'ctrl c':
self.exit()
keymap = {
"l": self.logger_view,
"g": self.graph_view,
"tab": self.small_blink_and_graph_view,
"esc": self.default_view
}
try:
requested_view = keymap[keypress]
except KeyError:
return True
if self.frame.body == requested_view:
self.frame.body = self.default_view
else:
self.frame.body = requested_view
return True
def _get_current_messages_per_second(self):
# Calculate deltas over the past window
deltas = [y - x for x, y in zip(list(self.ring_buffer), list(self.ring_buffer)[1:])]
# If the deque isn't empty
if deltas:
msgs_per_second = '{0:.2f}'.format(float(sum(deltas) / len(self.ring_buffer)) * 10)
else:
msgs_per_second = '{0:.2f}'.format(float(0.0))
return msgs_per_second
[docs] def before_recv(self, statedict):
pass
[docs] def after_recv(self, statedict, message):
pass
[docs] def before_connect(self):
statedict = {}
return statedict
[docs] def after_connect(self, statedict):
pass
[docs] def setup_websocket_connection(self, statedict):
ws_args = {
"uri": self.websocket_url.geturl(),
"extra_headers": self.extra_headers
}
if self.insecure_connection:
ws_args['ssl'] = ssl._create_unverified_context()
return ws_args
[docs] def get_identifier(self, statedict):
return next(self.socket_count)
[docs] def handle_exceptions(self, loop, context):
logging.error("Exception! : {}".format(str(context.get('exception'))))
[docs]class WebsocketTestingClientWithRandomApiTokenHeader(WebsocketTestingClient):
"""
Introduces a new parameter: `header_name` - used to specify the key to 'extra_headers' passed to `websocket.connect`
"""
def __init__(self, *args, **kwargs):
self.header_name = kwargs.pop("header_name", 'x-endpoint-token')
super().__init__(*args, **kwargs)
[docs] def before_connect(self):
statedict = super().before_connect()
# Generate a random API token
statedict['api_token'] = hashlib.sha256(os.urandom(4)).hexdigest()
return statedict
[docs] def setup_websocket_connection(self, statedict):
args = super().setup_websocket_connection(statedict)
args['extra_headers'] = {
self.header_name: statedict['api_token']
}
return args
[docs] def get_identifier(self, statedict):
return statedict['api_token'][:8]