# Source code for pywb.warcserver.inputrequest

from warcio.limitreader import LimitReader
from warcio.statusandheaders import StatusAndHeadersParser
from pywb.warcserver.amf import Amf
from pyamf.remoting import decode
from warcio.utils import to_native_str

from six.moves.urllib.parse import urlsplit, quote, unquote_plus, urlencode
from six import iteritems, StringIO, PY3
from io import BytesIO

import base64
import cgi
import json


#=============================================================================
class DirectWSGIInputRequest(object):
    """Read-only view of an incoming request backed by a WSGI environ dict.

    Exposes the request method, protocol, headers and body stored in ``env``
    and can serialize them back into a raw HTTP request
    (``reconstruct_request``).
    """

    def __init__(self, env):
        """Store the WSGI environ dict that all accessors read from."""
        self.env = env

    def get_req_method(self):
        """Return ``REQUEST_METHOD`` from the environ, upper-cased."""
        return self.env['REQUEST_METHOD'].upper()

    def get_req_protocol(self):
        """Return ``SERVER_PROTOCOL`` from the environ."""
        return self.env['SERVER_PROTOCOL']

    def get_referrer(self):
        """Return the ``Referer`` request header, or None if not present."""
        return self._get_header('Referer')

    def get_req_headers(self):
        """Return a dict of request headers rebuilt from the environ.

        ``HTTP_*`` keys (except ``HTTP_HOST``) plus ``CONTENT_LENGTH`` and
        ``CONTENT_TYPE`` are converted to ``Title-Case`` header names.
        Entries with an empty value are omitted.
        """
        headers = {}

        for name, value in self.env.items():
            # will be set by requests to match actual host
            if name == 'HTTP_HOST':
                continue

            if name.startswith('HTTP_'):
                name = name[5:].title().replace('_', '-')
            elif name in ('CONTENT_LENGTH', 'CONTENT_TYPE'):
                name = name.title().replace('_', '-')
            else:
                # not a header entry (e.g. wsgi.* or server variables)
                continue

            if value:
                headers[name] = value

        return headers

    def get_req_body(self):
        """Return a readable object for the request body, or None.

        With a Content-Length, ``wsgi.input`` is wrapped in a ``LimitReader``
        bounded to that length. With only a Transfer-Encoding header, the raw
        ``wsgi.input`` is returned. Otherwise there is no body.
        """
        input_ = self.env['wsgi.input']
        len_ = self._get_content_length()
        enc = self._get_header('Transfer-Encoding')

        if len_:
            data = LimitReader(input_, int(len_))
        elif enc:
            data = input_
        else:
            data = None

        return data

    def _get_content_type(self):
        return self.env.get('CONTENT_TYPE')

    def _get_content_length(self):
        return self.env.get('CONTENT_LENGTH')

    def _get_header(self, name):
        # WSGI stores request headers as HTTP_<UPPER_SNAKE_CASE> keys
        return self.env.get('HTTP_' + name.upper().replace('-', '_'))

    def include_method_query(self, url):
        """Return ``url`` with the request method/body folded into its query.

        GET and HEAD requests (and empty urls) are returned unchanged. For
        other methods the url is extended via ``MethodQueryCanonicalizer``.
        If the canonicalizer consumed the request body while doing so,
        ``wsgi.input`` is replaced with a rewound in-memory copy so the body
        can still be read later.
        """
        if not url:
            return url

        method = self.get_req_method()

        if method == 'GET' or method == 'HEAD':
            return url

        mime = self._get_content_type()
        length = self._get_content_length()
        stream = self.env['wsgi.input']

        buffered_stream = BytesIO()

        query = MethodQueryCanonicalizer(method, mime, length, stream,
                                         buffered_stream=buffered_stream,
                                         environ=self.env)

        new_url = query.append_query(url)

        # Only swap in the buffered copy when something was actually read
        # into it. The canonicalizer leaves the stream untouched when there
        # is no usable Content-Length (e.g. chunked uploads) or the method is
        # not POST/PUT; replacing wsgi.input with the empty buffer in those
        # cases would silently discard the request body.
        if new_url != url and buffered_stream.getvalue():
            self.env['wsgi.input'] = buffered_stream

        return new_url

    def get_full_request_uri(self):
        """Return the request path plus query string.

        ``REQUEST_URI`` is used as-is when present and no ``SCRIPT_NAME`` is
        set; otherwise the uri is rebuilt from the percent-quoted
        ``PATH_INFO`` and ``QUERY_STRING``.
        """
        req_uri = self.env.get('REQUEST_URI')
        if req_uri and not self.env.get('SCRIPT_NAME'):
            return req_uri

        req_uri = quote(self.env.get('PATH_INFO', ''),
                        safe='/~!$&\'()*+,;=:@')

        query = self.env.get('QUERY_STRING')
        if query:
            req_uri += '?' + query

        return req_uri

    def reconstruct_request(self, url=None):
        """Serialize the request as raw HTTP bytes.

        :param url: if given, a ``Host`` header is emitted using the netloc
            of this url.
        :return: the request line, headers and a blank line encoded as
            latin-1, followed by the request body (if any).
        """
        request_line = (self.get_req_method() + ' ' +
                        self.get_full_request_uri() + ' ' +
                        self.get_req_protocol())

        lines = [request_line]

        if url:
            lines.append('Host: ' + urlsplit(url).netloc)

        for name, value in self.get_req_headers().items():
            # Host is derived from the target url above, never copied over
            if name.lower() == 'host':
                continue

            lines.append(name + ': ' + value)

        # terminate the last header line and add the blank separator line
        buff = ('\r\n'.join(lines) + '\r\n\r\n').encode('latin-1')

        body = self.get_req_body()
        if body:
            buff += body.read()

        return buff
#=============================================================================
class POSTInputRequest(DirectWSGIInputRequest):
    """Request view whose request line and headers are parsed from the body.

    Instead of reading the method, uri and headers from the WSGI environ,
    the constructor parses a status line and header block from
    ``env['wsgi.input']`` with ``StatusAndHeadersParser`` and answers every
    accessor from that parsed result.
    NOTE(review): presumably the outer request body carries a complete raw
    HTTP request; confirm against callers.
    """

    def __init__(self, env):
        """Keep ``env`` and parse the leading line and headers from its input."""
        self.env = env

        # no expected status-line prefixes, and no verification of the line
        header_parser = StatusAndHeadersParser([], verify=False)
        self.status_headers = header_parser.parse(self.env['wsgi.input'])

    def get_req_method(self):
        """Return the first token of the parsed line (exposed as ``protocol``)."""
        return self.status_headers.protocol

    def get_req_headers(self):
        """Return the parsed headers as a dict (later duplicates win)."""
        return {name: value for name, value in self.status_headers.headers}

    def get_full_request_uri(self):
        """Return the part of the parsed status line before the first space."""
        pieces = self.status_headers.statusline.split(' ', 1)
        return pieces[0]

    def get_req_protocol(self):
        """Return the part of the parsed status line after the first space."""
        pieces = self.status_headers.statusline.split(' ', 1)
        return pieces[-1]

    def _get_content_type(self):
        return self.status_headers.get_header('Content-Type')

    def _get_content_length(self):
        return self.status_headers.get_header('Content-Length')

    def _get_header(self, name):
        return self.status_headers.get_header(name)
# ============================================================================
class MethodQueryCanonicalizer(object):
    """Fold a request's method and body into the url query string.

    ``append_query`` adds ``__wb_method=<METHOD>`` to the url for every
    method other than GET. For POST and PUT requests the constructor also
    reads ``length`` bytes from ``stream`` and converts them to query text
    according to ``mime``:

    * ``application/x-www-form-urlencoded``: decoded as UTF-8 and unquoted
    * ``multipart/*``: parsed with ``cgi.FieldStorage`` and url-encoded again
    * ``application/x-amf``: converted by ``amf_parse``
    * ``application/json``: flattened by ``json_parse`` (dropped on failure)
    * ``text/plain``: tried as JSON first
    * anything else, or a failed decode: ``__wb_post_data=<base64 of body>``

    The converted text is truncated to ``MAX_QUERY_LENGTH`` characters.
    """

    # MAX_POST_SIZE = 16384
    MAX_QUERY_LENGTH = 4096

    def __init__(self, method, mime, length, stream,
                 buffered_stream=None, environ=None):
        """Read the request body (POST/PUT only) and derive ``self.query``.

        :param method: request method, compared case-insensitively
        :param mime: Content-Type of the body, may be None
        :param length: Content-Length; anything that is not a positive
            integer leaves the stream unread
        :param stream: object with ``read(n)`` providing the body
        :param buffered_stream: optional writable, seekable stream that
            receives a rewound copy of the bytes read
        :param environ: passed through to ``amf_parse``
        """
        self.query = b''
        self.method = method.upper()

        if self.method not in ('POST', 'PUT'):
            return

        try:
            remaining = int(length)
        except (ValueError, TypeError):
            return

        if remaining <= 0:
            return

        # always read entire POST request, but limit query string later
        chunks = []
        while remaining > 0:
            chunk = stream.read(remaining)
            remaining -= len(chunk)

            if not chunk:
                break

            chunks.append(chunk)

        body = b''.join(chunks)

        if buffered_stream:
            buffered_stream.write(body)
            buffered_stream.seek(0)

        converted = self._convert_body(body, mime or '', environ)

        if converted:
            self.query = converted[:self.MAX_QUERY_LENGTH]

    def _convert_body(self, body, mime, environ):
        """Return the query text for ``body`` based on its ``mime`` type."""

        def handle_binary(raw):
            # opaque payloads are carried as a single base64 parameter
            return '__wb_post_data=' + to_native_str(base64.b64encode(raw))

        if mime.startswith('application/x-www-form-urlencoded'):
            try:
                text = to_native_str(body.decode('utf-8'))
                return unquote_plus(text)
            except UnicodeDecodeError:
                return handle_binary(body)

        if mime.startswith('multipart/'):
            form_env = {'REQUEST_METHOD': 'POST',
                        'CONTENT_TYPE': mime,
                        'CONTENT_LENGTH': len(body)}

            form_args = dict(fp=BytesIO(body),
                             environ=form_env,
                             keep_blank_values=True)

            if PY3:
                form_args['encoding'] = 'utf-8'

            try:
                form = cgi.FieldStorage(**form_args)
            except ValueError:
                # Content-Type multipart/form-data may lack "boundary" info
                return handle_binary(body)

            pairs = [(field.name, field.value) for field in form.list]
            return urlencode(pairs, True)

        if mime.startswith('application/x-amf'):
            return self.amf_parse(body, environ)

        if mime.startswith('application/json'):
            try:
                return self.json_parse(body)
            except Exception as e:
                print(e)
                return ''

        if mime.startswith('text/plain'):
            try:
                return self.json_parse(body)
            except Exception:
                return handle_binary(body)

        return handle_binary(body)

    def amf_parse(self, string, warn_on_error):
        """Decode an AMF body into a single url-encoded ``request`` parameter.

        Returns None (after printing the traceback) if decoding fails.
        ``warn_on_error`` is accepted but not used.
        """
        try:
            decoded = decode(BytesIO(string))
            representation = Amf.get_representation(decoded)
            return urlencode({"request": representation})
        except Exception as e:
            import traceback
            traceback.print_exc()
            print(e)
            return None

    def json_parse(self, string):
        """Flatten a JSON object into url-encoded ``key=value`` pairs.

        Nested objects are walked depth-first and only non-dict values are
        emitted, as ``str(value)``. A key that was already emitted gets the
        suffix ``.<count>_``, with the count starting at 2.
        """
        flat = {}
        repeats = {}

        def unique_key(name):
            if name not in flat:
                return name

            repeats[name] = repeats.get(name, 1) + 1
            return name + "." + str(repeats[name]) + "_"

        def flatten(mapping):
            for name, value in mapping.items():
                if isinstance(value, dict):
                    flatten(value)
                else:
                    flat[unique_key(name)] = str(value)

        flatten(json.loads(string))
        return urlencode(flat)

    def append_query(self, url):
        """Return ``url`` extended with ``__wb_method`` and the body query.

        GET requests are returned unchanged.
        """
        if self.method == 'GET':
            return url

        separator = '&' if '?' in url else '?'
        suffix = separator + "__wb_method=" + self.method

        if self.query:
            suffix += '&' + self.query

        return url + suffix