##############################################################################
#
# Copyright (c) 2001, 2002 Zope Foundation and Contributors.
# All Rights Reserved.
#
# This software is subject to the provisions of the Zope Public License,
# Version 2.1 (ZPL).  A copy of the ZPL should accompany this distribution.
# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED
# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS
# FOR A PARTICULAR PURPOSE.
#
##############################################################################

from collections import deque
import socket
import sys
import threading
import time

from .buffers import ReadOnlyFileBasedBuffer
from .utilities import build_http_date, logger, queue_logger

# Request headers that are stored in the WSGI environ under their own name
# rather than being given the "HTTP_" prefix.
rename_headers = {name: name for name in ("CONTENT_LENGTH", "CONTENT_TYPE")}

# Lower-cased names of the hop-by-hop headers; WSGITask's start_response
# refuses any application-supplied header whose lower-cased name is in here.
hop_by_hop = frozenset(
    [
        "connection",
        "keep-alive",
        "proxy-authenticate",
        "proxy-authorization",
        "te",
        "trailers",
        "transfer-encoding",
        "upgrade",
    ]
)


class ThreadedTaskDispatcher:
    """A task dispatcher that services queued tasks on a pool of threads.

    ``add_task`` appends tasks to a FIFO queue; every thread started through
    ``set_thread_count`` runs ``handler_thread``, which pops tasks off that
    queue and calls their ``service()`` method.  A single lock protects the
    queue, the thread bookkeeping and both condition variables.
    """

    stop_count = 0  # Number of threads that will stop soon.
    active_count = 0  # Number of currently active threads
    logger = logger
    queue_logger = queue_logger

    def __init__(self):
        # Numbers of the handler threads that were started and have not yet
        # removed themselves on exit.
        self.threads = set()
        # Pending tasks, oldest first.
        self.queue = deque()
        self.lock = threading.Lock()
        # Notified when a task is queued, when threads are asked to stop and
        # after pending tasks were cancelled by shutdown().
        self.queue_cv = threading.Condition(self.lock)
        # Notified by a handler thread right before it leaves its loop.
        self.thread_exit_cv = threading.Condition(self.lock)

    def start_new_thread(self, target, thread_no):
        """Start a daemon thread named ``waitress-<thread_no>``.

        The thread runs ``target(thread_no)``.
        """
        t = threading.Thread(
            target=target, name="waitress-{}".format(thread_no), args=(thread_no,)
        )
        t.daemon = True
        t.start()

    def handler_thread(self, thread_no):
        """Worker loop: service queued tasks until asked to stop.

        The loop ends when the thread observes a positive ``stop_count``; it
        then decrements that counter, removes ``thread_no`` from
        ``self.threads`` and notifies ``thread_exit_cv``.  Anything raised by
        ``task.service()`` is logged and the loop carries on.
        """
        while True:
            with self.lock:
                while not self.queue and self.stop_count == 0:
                    # Mark ourselves as idle before waiting to be
                    # woken up, then we will once again be active
                    self.active_count -= 1
                    self.queue_cv.wait()
                    self.active_count += 1

                if self.stop_count > 0:
                    self.active_count -= 1
                    self.stop_count -= 1
                    self.threads.discard(thread_no)
                    self.thread_exit_cv.notify()
                    break

                task = self.queue.popleft()
            # The task is serviced outside the lock so other threads can
            # keep using the queue meanwhile.
            try:
                task.service()
            except BaseException:
                self.logger.exception("Exception when servicing %r", task)

    def set_thread_count(self, count):
        """Start or retire handler threads so that ``count`` keep running.

        New threads get the lowest thread numbers not present in
        ``self.threads``.  Surplus threads are not stopped directly:
        ``stop_count`` is raised and all waiting threads are woken so that
        enough of them exit on their own.
        """
        with self.lock:
            threads = self.threads
            thread_no = 0
            # Threads already told to stop no longer count as running.
            running = len(threads) - self.stop_count
            while running < count:
                # Start threads.
                while thread_no in threads:
                    thread_no = thread_no + 1
                threads.add(thread_no)
                running += 1
                self.start_new_thread(self.handler_thread, thread_no)
                self.active_count += 1
                thread_no = thread_no + 1
            if running > count:
                # Stop threads.
                self.stop_count += running - count
                self.queue_cv.notify_all()

    def add_task(self, task):
        """Queue ``task`` and wake one waiting handler thread.

        A warning is written to ``queue_logger`` when the queue holds more
        tasks than there are idle threads.
        """
        with self.lock:
            self.queue.append(task)
            self.queue_cv.notify()
            queue_size = len(self.queue)
            idle_threads = len(self.threads) - self.stop_count - self.active_count
            if queue_size > idle_threads:
                self.queue_logger.warning(
                    "Task queue depth is %d", queue_size - idle_threads
                )

    def shutdown(self, cancel_pending=True, timeout=5):
        """Ask every handler thread to stop and wait for them to exit.

        ``timeout`` is the maximum number of seconds to wait; when it runs
        out a warning with the number of remaining threads is logged.  If
        ``cancel_pending`` is true, ``cancel()`` is called on every task
        still queued and ``True`` is returned; otherwise ``False``.
        """
        self.set_thread_count(0)
        # Ensure the threads shut down.
        threads = self.threads
        # Use the monotonic clock for the deadline: unlike time.time() it
        # cannot jump when the system clock is adjusted, so the timeout is
        # neither cut short nor extended by such a change.
        expiration = time.monotonic() + timeout
        with self.lock:
            while threads:
                if time.monotonic() >= expiration:
                    self.logger.warning("%d thread(s) still running", len(threads))
                    break
                # Waiting releases the lock so exiting threads can update
                # the bookkeeping; wake up regularly to re-check the deadline.
                self.thread_exit_cv.wait(0.1)
            if cancel_pending:
                # Cancel remaining tasks.
                queue = self.queue
                if queue:
                    self.logger.warning("Canceling %d pending task(s)", len(queue))
                while queue:
                    task = queue.popleft()
                    task.cancel()
                self.queue_cv.notify_all()
                return True
        return False


class Task:
    """Base class for the work done to answer one request on a channel.

    Subclasses provide ``execute()``.  ``service()`` runs ``start()``,
    ``execute()`` and ``finish()`` in that order, and ``write()`` sends the
    response header (once) followed by body data through the channel.
    """

    # Close the connection once the response has been sent.
    close_on_finish = False
    # Status line text that follows the HTTP version.
    status = "200 OK"
    # Set once the response header has been handed to the channel.
    wrote_header = False
    # Timestamp taken by start(); used for a generated Date header.
    start_time = 0
    # Announced body size, or None when it is not known.
    content_length = None
    # Number of body bytes accounted for by write().
    content_bytes_written = 0
    # Flags that keep each of the two write() warnings to one per task.
    logged_write_excess = False
    logged_write_no_body = False
    # write() refuses to run until this is true.
    complete = False
    # Body is being sent with chunked transfer encoding.
    chunked_response = False
    logger = logger

    def __init__(self, channel, request):
        self.channel = channel
        self.request = request
        self.response_headers = []
        # Any version other than 1.0 or 1.1 is answered as 1.0, a version
        # we support.
        requested = request.version
        self.version = requested if requested in ("1.0", "1.1") else "1.0"

    def service(self):
        """Run the task: ``start()``, ``execute()`` and then ``finish()``.

        An ``OSError`` marks the connection for closing; it is re-raised
        only when the channel's ``adj.log_socket_errors`` is true.
        """
        try:
            self.start()
            self.execute()
            self.finish()
        except OSError:
            self.close_on_finish = True
            if not self.channel.adj.log_socket_errors:
                return
            raise

    @property
    def has_body(self):
        """False when the status starts with "1", "204" or "304"."""
        return not self.status.startswith(("1", "204", "304"))

    def build_response_header(self):
        """Assemble the status line and headers and return them as bytes.

        Header names are re-capitalized, Content-Length / Connection /
        Transfer-Encoding / Server / Via / Date are filled in as needed, and
        ``self.response_headers``, ``self.close_on_finish`` and
        ``self.chunked_response`` are updated to match what is sent.
        """
        version = self.version
        # Figure out whether the connection should be closed.
        connection = self.request.headers.get("CONNECTION", "").lower()

        content_length_header = None
        date_header = None
        server_header = None
        connection_close_header = None
        outgoing = []

        for raw_name, value in self.response_headers:
            # replace with properly capitalized version
            name = "-".join(piece.capitalize() for piece in raw_name.split("-"))

            if name == "Content-Length":
                if not self.has_body:
                    # No Content-Length on a response without a body.
                    continue  # pragma: no cover
                content_length_header = value
            elif name == "Date":
                date_header = value
            elif name == "Server":
                server_header = value
            elif name == "Connection":
                connection_close_header = value.lower()

            outgoing.append((name, value))

        if (
            content_length_header is None
            and self.content_length is not None
            and self.has_body
        ):
            # The length is known but no header carried it yet.
            content_length_header = str(self.content_length)
            outgoing.append(("Content-Length", content_length_header))

        def mark_for_close():
            # Announce the close unless a Connection header is already there.
            if connection_close_header is None:
                outgoing.append(("Connection", "close"))
            self.close_on_finish = True

        if version == "1.0":
            # Keep-alive is only honoured when the body length is announced.
            if connection == "keep-alive" and content_length_header:
                outgoing.append(("Connection", "Keep-Alive"))
            else:
                mark_for_close()

        elif version == "1.1":
            if connection == "close":
                mark_for_close()

            if not content_length_header:
                # RFC 7230: MUST NOT send Transfer-Encoding or Content-Length
                # for any response with a status code of 1xx, 204 or 304.

                if self.has_body:
                    outgoing.append(("Transfer-Encoding", "chunked"))
                    self.chunked_response = True

                if not self.close_on_finish:
                    mark_for_close()

            # under HTTP 1.1 keep-alive is default, no need to set the header
        else:
            raise AssertionError("neither HTTP/1.0 or HTTP/1.1")

        # Set the Server and Date field, if not yet specified. This is needed
        # if the server is used as a proxy.
        ident = self.channel.server.adj.ident

        if server_header:
            outgoing.append(("Via", ident or "waitress"))
        elif ident:
            outgoing.append(("Server", ident))

        if not date_header:
            outgoing.append(("Date", build_http_date(self.start_time)))

        self.response_headers = outgoing

        # NB: sorting headers needs to preserve same-named-header order
        # as per RFC 2616 section 4.2; thus the sort key is the header name
        # only, relying on the stable sort to keep the relative position of
        # same-named headers
        lines = ["HTTP/%s %s" % (self.version, self.status)]
        lines.extend(
            "%s: %s" % pair
            for pair in sorted(self.response_headers, key=lambda pair: pair[0])
        )
        header_text = "%s\r\n\r\n" % "\r\n".join(lines)

        return header_text.encode("latin-1")

    def remove_content_length_header(self):
        """Drop every Content-Length entry from ``self.response_headers``."""
        self.response_headers = [
            (name, value)
            for name, value in self.response_headers
            if name.lower() != "content-length"
        ]

    def start(self):
        """Record the time at which servicing of the task began."""
        self.start_time = time.time()

    def finish(self):
        """Send whatever is still needed to complete the response."""
        if not self.wrote_header:
            # Nothing was written yet; an empty write sends the header.
            self.write(b"")
        if self.chunked_response:
            # not self.write, it will chunk it!
            self.channel.write_soon(b"0\r\n\r\n")

    def write(self, data):
        """Send ``data`` as body bytes, preceded by the header on first use.

        Raises ``RuntimeError`` when ``self.complete`` is false.  Data is
        chunk-encoded for chunked responses, truncated to the announced
        Content-Length otherwise, and ignored (with a warning) when the
        status does not allow a body.
        """
        if not self.complete:
            raise RuntimeError("start_response was not called before body written")

        channel = self.channel
        if not self.wrote_header:
            channel.write_soon(self.build_response_header())
            self.wrote_header = True

        if not data:
            return

        if not self.has_body:
            # Cheat, and tell the application we have written all of the bytes,
            # even though the response shouldn't have a body and we are
            # ignoring it entirely.
            self.content_bytes_written += len(data)

            if not self.logged_write_no_body:
                self.logger.warning(
                    "application-written content was ignored due to HTTP "
                    "response that may not contain a message-body: (%s)" % self.status
                )
                self.logged_write_no_body = True
            return

        towrite = data
        cl = self.content_length
        if self.chunked_response:
            # use chunked encoding response: upper-case hex size line, then
            # the data, each terminated by CRLF
            size_line = ("%X" % len(data)).encode("latin-1") + b"\r\n"
            towrite = size_line + (data + b"\r\n")
        elif cl is not None:
            # Never send more than what is left of the announced length.
            remaining = cl - self.content_bytes_written
            towrite = data[:remaining]
            self.content_bytes_written += len(towrite)
            if towrite != data and not self.logged_write_excess:
                self.logger.warning(
                    "application-written content exceeded the number of "
                    "bytes specified by Content-Length header (%s)" % cl
                )
                self.logged_write_excess = True

        if towrite:
            channel.write_soon(towrite)


class ErrorTask(Task):
    """A task that answers a request with the error recorded on it."""

    # No start_response step is involved, so write() may be used right away.
    complete = True

    def execute(self):
        """Send the response produced by ``request.error.to_response()``."""
        status, headers, body = self.request.error.to_response()
        self.status = status

        outgoing = self.response_headers
        outgoing.extend(headers)
        # We need to explicitly tell the remote client we are closing the
        # connection, because self.close_on_finish is set, and we are going to
        # slam the door in the clients face.
        outgoing.append(("Connection", "close"))
        self.close_on_finish = True

        # The length is taken from the text before it is encoded.
        self.content_length = len(body)
        self.write(body.encode("latin-1"))


class WSGITask(Task):
    """A task that runs the server's WSGI application for a request."""

    # Cached WSGI environ; built on the first get_environment() call.
    environ = None

    def execute(self):
        """Call the WSGI application and send the response it produces."""
        wsgi_environ = self.get_environment()

        def start_response(status, headers, exc_info=None):
            """Validate and record status and headers; return ``self.write``."""
            if exc_info:
                try:
                    if self.wrote_header:
                        # higher levels will catch and handle raised exception:
                        # 1. "service" method in task.py
                        # 2. "service" method in channel.py
                        # 3. "handler_thread" method in task.py
                        raise exc_info[1]
                    # As per WSGI spec existing headers must be cleared
                    self.response_headers = []
                finally:
                    exc_info = None
            elif self.complete:
                raise AssertionError(
                    "start_response called a second time without providing exc_info."
                )

            self.complete = True

            if status.__class__ is not str:
                raise AssertionError("status %s is not a string" % status)
            if "\n" in status or "\r" in status:
                raise ValueError(
                    "carriage return/line feed character present in status"
                )

            self.status = status

            # Prepare the headers for output
            for name, value in headers:
                if name.__class__ is not str:
                    raise AssertionError(
                        "Header name %r is not a string in %r" % (name, (name, value))
                    )
                if value.__class__ is not str:
                    raise AssertionError(
                        "Header value %r is not a string in %r" % (value, (name, value))
                    )

                if "\n" in value or "\r" in value:
                    raise ValueError(
                        "carriage return/line feed character present in header value"
                    )
                if "\n" in name or "\r" in name:
                    raise ValueError(
                        "carriage return/line feed character present in header name"
                    )

                lowered = name.lower()
                if lowered == "content-length":
                    self.content_length = int(value)
                elif lowered in hop_by_hop:
                    raise AssertionError(
                        '%s is a "hop-by-hop" header; it cannot be used by '
                        "a WSGI application (see PEP 3333)" % name
                    )

            self.response_headers.extend(headers)

            # Return a method used to write the response data.
            return self.write

        # Call the application to handle the request and write a response
        app_iter = self.channel.server.application(wsgi_environ, start_response)

        # Cleared once the channel has taken over closing app_iter.
        close_here = True
        try:
            if app_iter.__class__ is ReadOnlyFileBasedBuffer:
                announced = self.content_length
                size = app_iter.prepare(announced)
                if size:
                    if announced != size:
                        # The prepared size wins over an announced length
                        # that disagrees with it.
                        if announced is not None:
                            self.remove_content_length_header()
                        self.content_length = size
                    self.write(b"")  # generate headers
                    # if the write_soon below succeeds then the channel will
                    # take over closing the underlying file via the channel's
                    # _flush_some or handle_close so we intentionally avoid
                    # calling close in the finally block
                    self.channel.write_soon(app_iter)
                    close_here = False
                    return

            for position, chunk in enumerate(app_iter):
                if position == 0:
                    first_chunk_len = len(chunk)
                    # Set a Content-Length header if one is not supplied.
                    # start_response may not have been called until first
                    # iteration as per PEP, so we must reinterrogate
                    # self.content_length here
                    if (
                        self.content_length is None
                        and hasattr(app_iter, "__len__")
                        and len(app_iter) == 1
                    ):
                        self.content_length = first_chunk_len
                # transmit headers only after first iteration of the iterable
                # that returns a non-empty bytestring (PEP 3333)
                if chunk:
                    self.write(chunk)

            expected = self.content_length
            if expected is not None and self.content_bytes_written != expected:
                # close the connection so the client isn't sitting around
                # waiting for more data when there are too few bytes
                # to service content-length
                self.close_on_finish = True
                if self.request.command != "HEAD":
                    self.logger.warning(
                        "application returned too few bytes (%s) "
                        "for specified Content-Length (%s) via app_iter"
                        % (self.content_bytes_written, expected),
                    )
        finally:
            if close_here and hasattr(app_iter, "close"):
                app_iter.close()

    def get_environment(self):
        """Return the WSGI environ for the request, building it only once."""
        cached = self.environ
        if cached is not None:
            # Return the cached copy.
            return cached

        request = self.request
        channel = self.channel
        server = channel.server
        url_prefix = server.adj.url_prefix
        path = request.path

        if path.startswith("/"):
            # strip extra slashes at the beginning of a path that starts
            # with any number of slashes
            path = "/" + path.lstrip("/")

        if url_prefix:
            # NB: url_prefix is guaranteed by the configuration machinery to
            # be either the empty string or a string that starts with a single
            # slash and ends without any slashes
            if path == url_prefix:
                # if the path is the same as the url prefix, the SCRIPT_NAME
                # should be the url_prefix and PATH_INFO should be empty
                path = ""
            elif path.startswith(url_prefix + "/"):
                # if the path starts with the url prefix plus a slash,
                # the SCRIPT_NAME should be the url_prefix and PATH_INFO
                # should be the value of path from the slash until its end
                path = path[len(url_prefix) :]

        addr = channel.addr
        environ = {
            "REMOTE_ADDR": addr[0],
            # Nah, we aren't actually going to look up the reverse DNS for
            # REMOTE_ADDR, but we will happily set this environment variable
            # for the WSGI application. Spec says we can just set this to
            # REMOTE_ADDR, so we do.
            "REMOTE_HOST": addr[0],
            # try and set the REMOTE_PORT to something useful, but maybe None
            "REMOTE_PORT": str(addr[1]),
            "REQUEST_METHOD": request.command.upper(),
            "SERVER_PORT": str(server.effective_port),
            "SERVER_NAME": server.server_name,
            "SERVER_SOFTWARE": server.adj.ident,
            "SERVER_PROTOCOL": "HTTP/%s" % self.version,
            "SCRIPT_NAME": url_prefix,
            "PATH_INFO": path,
            "QUERY_STRING": request.query,
            "wsgi.url_scheme": request.url_scheme,
            # the following environment variables are required by the WSGI spec
            "wsgi.version": (1, 0),
            # apps should use the logging module
            "wsgi.errors": sys.stderr,
            "wsgi.multithread": True,
            "wsgi.multiprocess": False,
            "wsgi.run_once": False,
            "wsgi.input": request.get_body_stream(),
            "wsgi.file_wrapper": ReadOnlyFileBasedBuffer,
            "wsgi.input_terminated": True,  # wsgi.input is EOF terminated
        }

        for key, value in dict(request.headers).items():
            stripped = value.strip()
            environ_key = rename_headers.get(key)
            if environ_key is None:
                environ_key = "HTTP_" + key
            # Entries already present in the environ are never overwritten
            # by a request header.
            environ.setdefault(environ_key, stripped)

        # Insert a callable into the environment that allows the application to
        # check if the client disconnected. Only works with
        # channel_request_lookahead larger than 0.
        environ["waitress.client_disconnected"] = channel.check_client_disconnected

        # cache the environ for this request
        self.environ = environ
        return environ
