mirror of
https://github.com/mattermost/mattermost.git
synced 2026-02-03 20:40:00 -05:00
Go tests: - TestInvalidStatusCode_TooLow: verify status 42 returns 500 error - TestInvalidStatusCode_TooHigh: verify status 1000 returns 500 error - TestInvalidStatusCode_DoesNotPanic: ensure no panic for any invalid code - TestFlush_WithFlusher: verify flush calls http.Flusher.Flush() - TestFlush_WithoutFlusher: verify flush doesn't panic without Flusher - TestFlushGracefulDegradation: mirror plugin/http_test.go behavior - TestStreamResponseBody_WithFlush: verify flush flags processed in stream - TestEarlyResponse_CancelsRequestSend: verify early response stops sender Python tests: - test_flush_method_exists: verify flush() is callable - test_flush_applies_to_last_write: verify retroactive flush - test_flush_applies_to_next_write_when_empty: verify pre-write flush - test_multiple_writes_with_selective_flush: verify selective flushing - test_streaming_response_chunks: verify multiple writes = multiple messages - test_empty_body_response: verify no-body response handling - test_response_headers_in_first_message: verify header placement Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
845 lines
27 KiB
Python
845 lines
27 KiB
Python
# Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved.
|
|
# See LICENSE.txt for license information.
|
|
|
|
"""
|
|
Tests for ServeHTTP streaming hook implementation.
|
|
|
|
This module tests the Python-side ServeHTTP handler including:
|
|
- HTTPRequest wrapper class
|
|
- HTTPResponseWriter class
|
|
- Header conversion utilities
|
|
- Request body assembly from chunks
|
|
- Response streaming behavior
|
|
"""
|
|
|
|
import pytest
|
|
import asyncio
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
from mattermost_plugin import Plugin, hook, HookName
|
|
from mattermost_plugin.servicers.hooks_servicer import (
|
|
PluginHooksServicerImpl,
|
|
HTTPRequest,
|
|
HTTPResponseWriter,
|
|
_convert_headers_to_dict,
|
|
_convert_dict_to_headers,
|
|
)
|
|
from mattermost_plugin.grpc import hooks_http_pb2
|
|
from mattermost_plugin.grpc import hooks_common_pb2
|
|
|
|
|
|
# =============================================================================
|
|
# Helper functions
|
|
# =============================================================================
|
|
|
|
|
|
def make_plugin_context() -> hooks_common_pb2.PluginContext:
    """Build the fixed PluginContext attached to every test request."""
    context_fields = {
        "session_id": "session123",
        "request_id": "request123",
    }
    return hooks_common_pb2.PluginContext(**context_fields)
|
|
|
|
|
|
def make_request_init(
    method: str = "GET",
    url: str = "/plugins/test/api/hello",
    headers: list = None,
) -> hooks_http_pb2.ServeHTTPRequestInit:
    """Build a ServeHTTPRequestInit message for use in tests.

    Args:
        method: HTTP method placed in the message.
        url: Value used for both the ``url`` and ``request_uri`` fields.
        headers: Header entries for the message; ``None`` means no headers.

    Returns:
        A ServeHTTPRequestInit with HTTP/1.1 protocol fields, a fixed host
        and remote address, and ``content_length`` set to -1.
    """
    init_fields = dict(
        plugin_context=make_plugin_context(),
        method=method,
        url=url,
        proto="HTTP/1.1",
        proto_major=1,
        proto_minor=1,
        # A fresh list per call avoids sharing a mutable default.
        headers=[] if headers is None else headers,
        host="localhost:8065",
        remote_addr="127.0.0.1:12345",
        request_uri=url,
        content_length=-1,
    )
    return hooks_http_pb2.ServeHTTPRequestInit(**init_fields)
|
|
|
|
|
|
async def collect_responses(async_gen):
    """Drain *async_gen* and return everything it yielded as a list."""
    return [item async for item in async_gen]
|
|
|
|
|
|
# =============================================================================
|
|
# HTTPRequest Tests
|
|
# =============================================================================
|
|
|
|
|
|
class TestHTTPRequest:
    """Checks for the HTTPRequest wrapper: attributes, headers, and body."""

    @staticmethod
    def _new_request(**overrides):
        """Construct an HTTPRequest from baseline GET fields plus *overrides*."""
        fields = {
            "method": "GET",
            "url": "/test",
            "proto": "HTTP/1.1",
            "proto_major": 1,
            "proto_minor": 1,
            "headers": {},
            "host": "localhost",
            "remote_addr": "127.0.0.1",
            "request_uri": "/test",
            "content_length": 0,
        }
        fields.update(overrides)
        return HTTPRequest(**fields)

    def test_basic_properties(self):
        """Constructor arguments are exposed as same-named attributes."""
        request = self._new_request(
            method="POST",
            url="http://localhost/api/v1/test",
            headers={"Content-Type": ["application/json"], "Accept": ["text/html"]},
            host="localhost:8065",
            remote_addr="192.168.1.1:54321",
            request_uri="/api/v1/test",
            content_length=100,
            plugin_context=None,
        )

        assert request.method == "POST"
        assert request.url == "http://localhost/api/v1/test"
        assert request.host == "localhost:8065"
        assert request.content_length == 100

    def test_get_header_case_insensitive(self):
        """get_header finds a header regardless of the key's letter case."""
        request = self._new_request(
            headers={"Content-Type": ["application/json"], "X-Custom-Header": ["value1"]},
        )

        # Exact, lower-case and upper-case spellings must all resolve.
        for spelling in ("Content-Type", "content-type", "CONTENT-TYPE"):
            assert request.get_header(spelling) == "application/json"

        # A lower-cased lookup of a mixed-case key also resolves.
        assert request.get_header("x-custom-header") == "value1"

    def test_get_header_default(self):
        """A missing header yields "" or the caller-supplied default."""
        request = self._new_request()

        assert request.get_header("X-Missing") == ""
        assert request.get_header("X-Missing", "default") == "default"

    def test_get_all_headers(self):
        """get_all_headers returns every value stored for a header."""
        request = self._new_request(headers={"Set-Cookie": ["a=1", "b=2", "c=3"]})

        cookie_values = request.get_all_headers("Set-Cookie")

        assert len(cookie_values) == 3
        assert "a=1" in cookie_values
        assert "b=2" in cookie_values

    def test_body_attribute(self):
        """The body attribute can be assigned and read back."""
        request = self._new_request(method="POST", content_length=11)
        payload = b"hello world"

        request.body = payload

        assert request.body == b"hello world"
|
|
|
|
|
|
# =============================================================================
|
|
# HTTPResponseWriter Tests
|
|
# =============================================================================
|
|
|
|
|
|
class TestHTTPResponseWriter:
    """Checks for HTTPResponseWriter: status, headers, body, and flushing."""

    def test_default_status_code(self):
        """Writing a body without an explicit status leaves the status at 200."""
        writer = HTTPResponseWriter()

        writer.write(b"hello")

        assert writer.status_code == 200

    def test_explicit_status_code(self):
        """A status passed to write_header is the one reported."""
        writer = HTTPResponseWriter()

        writer.write_header(404)
        writer.write(b"Not Found")

        assert writer.status_code == 404

    def test_set_header(self):
        """set_header stores the value as a single-element list."""
        writer = HTTPResponseWriter()

        writer.set_header("Content-Type", "application/json")

        assert writer.headers["Content-Type"] == ["application/json"]

    def test_add_header_multi_value(self):
        """add_header appends values under the same key in call order."""
        writer = HTTPResponseWriter()

        for cookie in ("a=1", "b=2"):
            writer.add_header("Set-Cookie", cookie)

        assert writer.headers["Set-Cookie"] == ["a=1", "b=2"]

    def test_write_string_auto_encode(self):
        """A str passed to write comes back from get_body as bytes."""
        writer = HTTPResponseWriter()

        writer.write("hello")

        assert writer.get_body() == b"hello"

    def test_write_multiple_chunks(self):
        """Successive writes are concatenated by get_body."""
        writer = HTTPResponseWriter()

        for piece in (b"hello ", b"world"):
            writer.write(piece)

        assert writer.get_body() == b"hello world"

    def test_header_warning_after_write(self):
        """A header set after the first write is dropped rather than raising."""
        writer = HTTPResponseWriter()
        writer.write(b"body")

        # Expected to be rejected without an exception.
        writer.set_header("X-Late", "value")

        # The late header must not appear in the header map.
        assert "X-Late" not in writer.headers

    # ==========================================================================
    # Flush Tests (Phase 8.2)
    # ==========================================================================

    def test_flush_method_exists(self):
        """flush can be called on a fresh writer without raising."""
        writer = HTTPResponseWriter()

        # Passing means no exception was raised.
        writer.flush()

    def test_flush_applies_to_last_write(self):
        """flush marks only the most recent pending write."""
        writer = HTTPResponseWriter()
        writer.write(b"chunk1")
        writer.write(b"chunk2")

        writer.flush()

        queued = writer.get_pending_writes()
        assert len(queued) == 2
        # Index 1 of each entry is the flush flag: earlier write untouched,
        # latest write flagged.
        assert queued[0][1] is False
        assert queued[1][1] is True

    def test_flush_applies_to_next_write_when_empty(self):
        """flush with nothing pending carries over to the next write."""
        writer = HTTPResponseWriter()

        writer.flush()  # Nothing has been written yet
        writer.write(b"chunk1")

        queued = writer.get_pending_writes()
        assert len(queued) == 1
        # The earlier flush is applied to this first write.
        assert queued[0][1] is True

    def test_multiple_writes_with_selective_flush(self):
        """Only the writes immediately followed by flush carry the flag."""
        writer = HTTPResponseWriter()
        writer.write(b"chunk1")
        writer.flush()
        writer.write(b"chunk2")
        writer.write(b"chunk3")
        writer.flush()

        queued = writer.get_pending_writes()
        assert len(queued) == 3

        # chunk1 and chunk3 were each followed by flush(); chunk2 was not.
        expected_entries = (
            (b"chunk1", True),
            (b"chunk2", False),
            (b"chunk3", True),
        )
        for position, expected in enumerate(expected_entries):
            assert queued[position] == expected

    def test_get_pending_writes(self):
        """get_pending_writes lists each write paired with its flush flag."""
        writer = HTTPResponseWriter()
        payloads = (b"a", b"b", b"c")
        for payload in payloads:
            writer.write(payload)

        queued = writer.get_pending_writes()

        assert len(queued) == 3
        for position, payload in enumerate(payloads):
            assert queued[position] == (payload, False)

    def test_clear_pending_writes(self):
        """clear_pending_writes empties the pending list."""
        writer = HTTPResponseWriter()
        writer.write(b"chunk")
        assert len(writer.get_pending_writes()) == 1

        writer.clear_pending_writes()

        assert len(writer.get_pending_writes()) == 0

    def test_max_chunk_size_constant(self):
        """MAX_CHUNK_SIZE is exposed on the class and equals 64 KiB."""
        assert HTTPResponseWriter.MAX_CHUNK_SIZE == 64 * 1024

    def test_get_body_still_works(self):
        """get_body returns the whole body even with a flush between writes."""
        writer = HTTPResponseWriter()
        writer.write(b"hello ")
        writer.flush()
        writer.write(b"world")

        # Flushing must not split or drop data from the assembled body.
        assert writer.get_body() == b"hello world"
|
|
|
|
|
|
# =============================================================================
|
|
# Header Conversion Tests
|
|
# =============================================================================
|
|
|
|
|
|
class TestHeaderConversion:
    """Checks for the proto <-> dict header conversion helpers."""

    def test_proto_to_dict_empty(self):
        """An empty header list converts to an empty dict."""
        assert _convert_headers_to_dict([]) == {}

    def test_proto_to_dict_single_value(self):
        """A header with one value becomes a one-element list under its key."""
        proto_headers = [
            hooks_http_pb2.HTTPHeader(key="Content-Type", values=["application/json"]),
        ]

        converted = _convert_headers_to_dict(proto_headers)

        assert converted == {"Content-Type": ["application/json"]}

    def test_proto_to_dict_multi_value(self):
        """All values of a multi-value header are kept, in order."""
        proto_headers = [
            hooks_http_pb2.HTTPHeader(key="Accept", values=["text/html", "application/json"]),
        ]

        converted = _convert_headers_to_dict(proto_headers)

        assert converted == {"Accept": ["text/html", "application/json"]}

    def test_dict_to_proto(self):
        """Each dict key becomes one proto header carrying its values."""
        header_map = {
            "Content-Type": ["application/json"],
            "X-Custom": ["val1", "val2"],
        }

        proto_headers = _convert_dict_to_headers(header_map)

        assert len(proto_headers) == 2
        # Locate the Content-Type entry without assuming output order.
        content_type = next(
            entry for entry in proto_headers if entry.key == "Content-Type"
        )
        assert list(content_type.values) == ["application/json"]

    def test_dict_to_proto_string_value(self):
        """A bare string value is treated as a single-element value list."""
        header_map = {"Content-Type": "text/plain"}

        proto_headers = _convert_dict_to_headers(header_map)

        assert len(proto_headers) == 1
        assert list(proto_headers[0].values) == ["text/plain"]
|
|
|
|
|
|
# =============================================================================
|
|
# ServeHTTP Servicer Tests
|
|
# =============================================================================
|
|
|
|
|
|
class TestServeHTTPServicer:
    """Checks for the ServeHTTP method of the gRPC hooks servicer."""

    @pytest.fixture
    def simple_plugin(self):
        """Plugin whose handler records each request and replies with 200."""

        class SimplePlugin(Plugin):
            def __init__(self):
                super().__init__()
                self.received_requests = []

            @hook(HookName.ServeHTTP)
            def serve_http(self, ctx, w, r):
                # Record what the handler saw so tests can inspect it later.
                seen = {
                    "method": r.method,
                    "url": r.url,
                    "body": r.body,
                }
                self.received_requests.append(seen)
                w.set_header("Content-Type", "text/plain")
                w.write_header(200)
                w.write(f"Hello! Method: {r.method}")

        return SimplePlugin()

    @pytest.fixture
    def error_plugin(self):
        """Plugin whose handler always raises ValueError."""

        class ErrorPlugin(Plugin):
            @hook(HookName.ServeHTTP)
            def serve_http(self, ctx, w, r):
                raise ValueError("Test error")

        return ErrorPlugin()

    @pytest.fixture
    def no_handler_plugin(self):
        """Plugin that registers no ServeHTTP hook."""

        class NoHandlerPlugin(Plugin):
            pass

        return NoHandlerPlugin()

    @pytest.mark.asyncio
    async def test_simple_get_request(self, simple_plugin):
        """A body-less GET yields a 200 response containing the greeting."""
        servicer = PluginHooksServicerImpl(simple_plugin)

        # The whole request fits in one message: init and end-of-body marker.
        only_message = hooks_http_pb2.ServeHTTPRequest(
            init=make_request_init(method="GET", url="/hello"),
            body_complete=True,
        )

        async def request_stream():
            yield only_message

        grpc_context = MagicMock(**{"cancelled.return_value": False})

        responses = await collect_responses(
            servicer.ServeHTTP(request_stream(), grpc_context)
        )

        assert len(responses) >= 1
        opening = responses[0]
        assert opening.init.status_code == 200
        assert b"Hello! Method: GET" in opening.body_chunk
        assert responses[-1].body_complete is True

    @pytest.mark.asyncio
    async def test_post_with_body(self, simple_plugin):
        """A body split over two messages reaches the handler reassembled."""
        servicer = PluginHooksServicerImpl(simple_plugin)

        # First message carries init plus the first part of the body;
        # the second carries the rest and closes the body.
        messages = [
            hooks_http_pb2.ServeHTTPRequest(
                init=make_request_init(method="POST", url="/api/create"),
                body_chunk=b"hello ",
                body_complete=False,
            ),
            hooks_http_pb2.ServeHTTPRequest(
                body_chunk=b"world",
                body_complete=True,
            ),
        ]

        async def request_stream():
            for message in messages:
                yield message

        grpc_context = MagicMock(**{"cancelled.return_value": False})

        await collect_responses(servicer.ServeHTTP(request_stream(), grpc_context))

        # The handler must have run once and seen the joined body.
        assert len(simple_plugin.received_requests) == 1
        recorded = simple_plugin.received_requests[0]
        assert recorded["body"] == b"hello world"
        assert recorded["method"] == "POST"

    @pytest.mark.asyncio
    async def test_handler_error_returns_500(self, error_plugin):
        """An exception raised by the handler is reported as status 500."""
        servicer = PluginHooksServicerImpl(error_plugin)

        only_message = hooks_http_pb2.ServeHTTPRequest(
            init=make_request_init(),
            body_complete=True,
        )

        async def request_stream():
            yield only_message

        grpc_context = MagicMock(**{"cancelled.return_value": False})

        responses = await collect_responses(
            servicer.ServeHTTP(request_stream(), grpc_context)
        )

        assert len(responses) >= 1
        assert responses[0].init.status_code == 500

    @pytest.mark.asyncio
    async def test_no_handler_returns_404(self, no_handler_plugin):
        """A plugin without a ServeHTTP hook produces status 404."""
        servicer = PluginHooksServicerImpl(no_handler_plugin)

        only_message = hooks_http_pb2.ServeHTTPRequest(
            init=make_request_init(),
            body_complete=True,
        )

        async def request_stream():
            yield only_message

        grpc_context = MagicMock(**{"cancelled.return_value": False})

        responses = await collect_responses(
            servicer.ServeHTTP(request_stream(), grpc_context)
        )

        assert len(responses) >= 1
        assert responses[0].init.status_code == 404

    @pytest.mark.asyncio
    async def test_empty_request_stream(self):
        """A request stream that yields nothing produces status 500."""

        class DummyPlugin(Plugin):
            pass

        servicer = PluginHooksServicerImpl(DummyPlugin())

        async def empty_stream():
            # The unreachable yield makes this an async generator that
            # finishes without producing any message.
            if False:
                yield

        grpc_context = MagicMock(**{"cancelled.return_value": False})

        responses = await collect_responses(
            servicer.ServeHTTP(empty_stream(), grpc_context)
        )

        # The servicer is expected to answer with an error response.
        assert len(responses) >= 1
        assert responses[0].init.status_code == 500

    @pytest.mark.asyncio
    async def test_request_headers_passed_to_handler(self, simple_plugin):
        """A request carrying headers is delivered to the handler once."""
        servicer = PluginHooksServicerImpl(simple_plugin)

        request_headers = [
            hooks_http_pb2.HTTPHeader(key="Content-Type", values=["application/json"]),
            hooks_http_pb2.HTTPHeader(key="Authorization", values=["Bearer token123"]),
        ]
        only_message = hooks_http_pb2.ServeHTTPRequest(
            init=make_request_init(headers=request_headers),
            body_complete=True,
        )

        async def request_stream():
            yield only_message

        grpc_context = MagicMock(**{"cancelled.return_value": False})

        await collect_responses(servicer.ServeHTTP(request_stream(), grpc_context))

        # Only delivery is checked here; header values are not inspected.
        assert len(simple_plugin.received_requests) == 1
|
|
|
|
|
|
# =============================================================================
|
|
# Chunking Behavior Tests
|
|
# =============================================================================
|
|
|
|
|
|
class TestChunkingBehavior:
    """Checks for reassembly of a request body delivered in chunks."""

    @pytest.mark.asyncio
    async def test_large_body_assembly(self):
        """Chunks sent across several messages arrive as one joined body."""

        class BodyCapturingPlugin(Plugin):
            def __init__(self):
                super().__init__()
                self.captured_body = None

            @hook(HookName.ServeHTTP)
            def serve_http(self, ctx, w, r):
                self.captured_body = r.body
                w.write(b"OK")

        plugin = BodyCapturingPlugin()
        servicer = PluginHooksServicerImpl(plugin)

        # Three chunks of 100, 100 and 50 bytes (250 bytes in total).
        chunks = [b"a" * 100, b"b" * 100, b"c" * 50]
        final_index = len(chunks) - 1

        async def request_stream():
            for index, chunk in enumerate(chunks):
                fields = {
                    "body_chunk": chunk,
                    "body_complete": index == final_index,
                }
                # Only the first message of a request carries the init part.
                if index == 0:
                    fields["init"] = make_request_init(method="POST")
                yield hooks_http_pb2.ServeHTTPRequest(**fields)

        grpc_context = MagicMock(**{"cancelled.return_value": False})

        await collect_responses(servicer.ServeHTTP(request_stream(), grpc_context))

        # The handler must have seen the chunks joined in order.
        assert plugin.captured_body == b"".join(chunks)
        assert len(plugin.captured_body) == 250
|
|
|
|
|
|
# =============================================================================
|
|
# Cancellation Tests
|
|
# =============================================================================
|
|
|
|
|
|
class TestCancellation:
    """Checks for how ServeHTTP copes with a cancelled gRPC context."""

    @pytest.mark.asyncio
    async def test_cancellation_during_body_read(self):
        """ServeHTTP completes without raising when cancelled mid-body."""

        class SlowPlugin(Plugin):
            @hook(HookName.ServeHTTP)
            def serve_http(self, ctx, w, r):
                w.write(b"Should not reach here")

        servicer = PluginHooksServicerImpl(SlowPlugin())

        # The context reports "not cancelled" on the first query and
        # "cancelled" on every query after that.
        grpc_context = MagicMock()
        checks_made = 0

        def check_cancelled():
            nonlocal checks_made
            checks_made += 1
            return checks_made > 1

        grpc_context.cancelled = check_cancelled

        opening_message = hooks_http_pb2.ServeHTTPRequest(
            init=make_request_init(method="POST"),
            body_chunk=b"first chunk",
            body_complete=False,
        )
        closing_message = hooks_http_pb2.ServeHTTPRequest(
            body_chunk=b"second chunk",
            body_complete=True,
        )

        async def slow_request_stream():
            yield opening_message
            # Pause briefly before delivering the rest of the body.
            await asyncio.sleep(0.01)
            yield closing_message

        # NOTE(review): no assertion is made on the responses. The outcome
        # (no responses vs. an error response) depends on when the servicer
        # notices the cancellation, so for now this only verifies that
        # draining the response stream raises nothing.
        await collect_responses(
            servicer.ServeHTTP(slow_request_stream(), grpc_context)
        )
|
|
|
|
|
|
# =============================================================================
|
|
# Response Streaming Tests (Phase 8.2)
|
|
# =============================================================================
|
|
|
|
|
|
class TestResponseStreaming:
    """Checks for how handler writes and flushes map to response messages."""

    @pytest.fixture
    def streaming_plugin(self):
        """Plugin that writes three chunks, flushing after the 1st and 3rd."""

        class StreamingPlugin(Plugin):
            @hook(HookName.ServeHTTP)
            def serve_http(self, ctx, w, r):
                w.set_header("Content-Type", "text/event-stream")
                w.write(b"chunk1")
                w.flush()
                w.write(b"chunk2")
                w.write(b"chunk3")
                w.flush()

        return StreamingPlugin()

    @pytest.fixture
    def empty_body_plugin(self):
        """Plugin that sets a header and status 204 but writes no body."""

        class EmptyBodyPlugin(Plugin):
            @hook(HookName.ServeHTTP)
            def serve_http(self, ctx, w, r):
                w.set_header("Content-Type", "text/plain")
                w.write_header(204)  # No Content

        return EmptyBodyPlugin()

    @pytest.mark.asyncio
    async def test_streaming_response_chunks(self, streaming_plugin):
        """Each write() becomes its own message with the right flush flag."""
        servicer = PluginHooksServicerImpl(streaming_plugin)

        only_message = hooks_http_pb2.ServeHTTPRequest(
            init=make_request_init(method="GET", url="/stream"),
            body_complete=True,
        )

        async def request_stream():
            yield only_message

        grpc_context = MagicMock(**{"cancelled.return_value": False})

        responses = await collect_responses(
            servicer.ServeHTTP(request_stream(), grpc_context)
        )

        # One response message per write() call.
        assert len(responses) == 3

        # The opening message carries the init section.
        assert responses[0].init is not None
        assert responses[0].init.status_code == 200

        # (body chunk, flush flag, body_complete) expected per message:
        # chunk1 and chunk3 were flushed; only the last message is final.
        expected_messages = (
            (b"chunk1", True, False),
            (b"chunk2", False, False),
            (b"chunk3", True, True),
        )
        for message, (chunk, flushed, complete) in zip(responses, expected_messages):
            assert message.body_chunk == chunk
            assert message.flush is flushed
            assert message.body_complete is complete

    @pytest.mark.asyncio
    async def test_empty_body_response(self, empty_body_plugin):
        """A handler that writes nothing yields one final, body-less message."""
        servicer = PluginHooksServicerImpl(empty_body_plugin)

        only_message = hooks_http_pb2.ServeHTTPRequest(
            init=make_request_init(method="GET"),
            body_complete=True,
        )

        async def request_stream():
            yield only_message

        grpc_context = MagicMock(**{"cancelled.return_value": False})

        responses = await collect_responses(
            servicer.ServeHTTP(request_stream(), grpc_context)
        )

        # Exactly one message: init only, empty body, marked complete.
        assert len(responses) == 1
        sole_response = responses[0]
        assert sole_response.init is not None
        assert sole_response.init.status_code == 204
        assert sole_response.body_chunk == b""
        assert sole_response.body_complete is True

    @pytest.mark.asyncio
    async def test_response_headers_in_first_message(self):
        """Response headers travel in the first message and not the second."""

        class HeaderPlugin(Plugin):
            @hook(HookName.ServeHTTP)
            def serve_http(self, ctx, w, r):
                w.set_header("Content-Type", "text/plain")
                w.set_header("X-Custom", "value")
                w.write(b"chunk1")
                w.write(b"chunk2")

        servicer = PluginHooksServicerImpl(HeaderPlugin())

        only_message = hooks_http_pb2.ServeHTTPRequest(
            init=make_request_init(),
            body_complete=True,
        )

        async def request_stream():
            yield only_message

        grpc_context = MagicMock(**{"cancelled.return_value": False})

        responses = await collect_responses(
            servicer.ServeHTTP(request_stream(), grpc_context)
        )

        # The opening message must list both headers set by the handler.
        assert responses[0].init is not None
        sent_keys = [header.key for header in responses[0].init.headers]
        assert "Content-Type" in sent_keys
        assert "X-Custom" in sent_keys

        # A follow-up message, if any, must carry no init section
        # (or only an empty, default-valued one).
        if len(responses) > 1:
            follow_up = responses[1]
            init_absent = not follow_up.HasField("init") or follow_up.init is None
            assert init_absent or (
                follow_up.init.status_code == 0 and len(follow_up.init.headers) == 0
            )

    @pytest.mark.asyncio
    async def test_single_write_single_response(self):
        """A single write() yields exactly one, final response message."""

        class SingleWritePlugin(Plugin):
            @hook(HookName.ServeHTTP)
            def serve_http(self, ctx, w, r):
                w.set_header("Content-Type", "text/plain")
                w.write(b"single response body")

        servicer = PluginHooksServicerImpl(SingleWritePlugin())

        only_message = hooks_http_pb2.ServeHTTPRequest(
            init=make_request_init(),
            body_complete=True,
        )

        async def request_stream():
            yield only_message

        grpc_context = MagicMock(**{"cancelled.return_value": False})

        responses = await collect_responses(
            servicer.ServeHTTP(request_stream(), grpc_context)
        )

        # One write, one message, and that message closes the body.
        assert len(responses) == 1
        sole_response = responses[0]
        assert sole_response.init.status_code == 200
        assert sole_response.body_chunk == b"single response body"
        assert sole_response.body_complete is True
|