
Overview

Federated learning across MTAs enables collaborative spam detection without sharing private email data.

Problem Statement

Current spam filtering has several weaknesses:

• Each MTA learns independently (isolated data)
• Training data cannot be shared (privacy concerns)
• Spammers exploit small providers (too little data to learn from)
• No network effect
• Reactive rather than predictive

Vision

Federated Anti-Spam Network

msgs.global MTA ──┐
                  ├──> Federated Learning ──> Shared Model
Other MTAs     ───┤     (privacy-preserving)    (benefits all)
Small Providers ──┘

• Train on local data (privacy preserved)
• Share model updates (not raw data)
• Global spam patterns detected
• Small providers benefit from network

Architecture

1. Federated Learning Model

Each MTA:
  1. Trains spam classifier on local email data
  2. Sends model updates (gradients) to coordinator
  3. Coordinator aggregates updates from all MTAs
  4. Distributes improved global model back to MTAs

Privacy: Raw email data never leaves the MTA
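
The four steps above reduce to a single federated round: local training at each MTA, then a weighted average at the coordinator. A minimal NumPy sketch (the `local_train` callback and the dataset list are hypothetical stand-ins for the real training code shown later):

```python
import numpy as np

def federated_round(global_weights, mta_datasets, local_train):
    """One round: each MTA trains locally, the coordinator averages the results."""
    updates, sizes = [], []
    for X, y in mta_datasets:
        # Steps 1-2: local training starting from the current global model;
        # only the resulting weights leave the MTA, never (X, y)
        updates.append(local_train(global_weights, X, y))
        sizes.append(len(y))
    # Step 3: weighted average (FedAvg), weighting each MTA by its dataset size
    total = sum(sizes)
    new_weights = sum(w * (n / total) for w, n in zip(updates, sizes))
    # Step 4: the caller distributes new_weights back to every MTA
    return new_weights
```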

2. Spam Features

def extract_features(message):
    """Extract spam detection features"""
    return {
        # Content features
        'subject_length': len(message.subject),
        'body_length': len(message.body),
        'has_links': count_links(message.body),
        'link_domains': extract_domains(message.body),
        'has_attachments': len(message.attachments) > 0,
        'attachment_types': [a.mime_type for a in message.attachments],

        # Sender features
        'sender_domain_age': get_domain_age(message.from_domain),
        'sender_reputation': get_sender_reputation(message.from_email),
        'spf_result': message.spf_result,
        'dkim_result': message.dkim_result,
        'dmarc_result': message.dmarc_result,

        # Behavioral features
        'time_of_day': message.sent_at.hour,
        'day_of_week': message.sent_at.weekday(),
        'recipient_count': len(message.to),
        'bcc_count': len(message.bcc),

        # Linguistic features
        'spam_words': count_spam_keywords(message),
        'urgency_score': detect_urgency(message),
        'financial_terms': contains_financial_keywords(message),
        'sentiment': analyze_sentiment(message.body),

        # Metadata features
        'headers_suspicious': check_suspicious_headers(message),
        'ip_reputation': get_ip_reputation(message.originating_ip),
        'geographic_origin': geolocate(message.originating_ip)
    }
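
The classifier below expects a fixed-length numeric vector, so the feature dict has to be vectorized first. A minimal sketch (the feature ordering, bucket counts, and the `vectorize_features` name are illustrative assumptions, not part of the design):

```python
import numpy as np

# Numeric features used directly; names mirror the extract_features() output
NUMERIC_FEATURES = [
    'subject_length', 'body_length', 'has_links', 'recipient_count',
    'bcc_count', 'spam_words', 'urgency_score', 'time_of_day',
]

def vectorize_features(features, dim=100):
    """Map a feature dict to a fixed-length vector for the neural network."""
    vec = np.zeros(dim, dtype=np.float32)
    for i, name in enumerate(NUMERIC_FEATURES):
        vec[i] = float(features.get(name, 0) or 0)
    # Hash categorical/list-valued features into the remaining buckets
    offset = len(NUMERIC_FEATURES)
    for name in ('link_domains', 'attachment_types'):
        for value in features.get(name, []) or []:
            bucket = offset + hash((name, value)) % (dim - offset)
            vec[bucket] += 1.0
    return vec
```

In practice the hashing scheme would have to be identical on every MTA, since the federated model's input layout must match across the network.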

3. Local Model Training

import numpy as np
import tensorflow as tf
import tensorflow_federated as tff

class SpamClassifier:
    """Local spam classifier for federated learning"""

    def __init__(self):
        self.model = self.build_model()

    def build_model(self):
        """Neural network for spam classification"""
        model = tf.keras.Sequential([
            tf.keras.layers.Dense(128, activation='relu', input_shape=(100,)),
            tf.keras.layers.Dropout(0.3),
            tf.keras.layers.Dense(64, activation='relu'),
            tf.keras.layers.Dropout(0.3),
            tf.keras.layers.Dense(32, activation='relu'),
            tf.keras.layers.Dense(1, activation='sigmoid')  # Spam probability
        ])

        model.compile(
            optimizer='adam',
            loss='binary_crossentropy',
            metrics=['accuracy', 'precision', 'recall']
        )

        return model

    def train_local(self, messages, labels):
        """Train on local data"""
        # Extract features and convert each dict to a fixed-length numeric vector
        X = np.array([vectorize_features(extract_features(msg)) for msg in messages])
        y = np.array(labels)

        # Train model
        self.model.fit(X, y, epochs=10, validation_split=0.2)

        # Return model updates (weights)
        return self.model.get_weights()

    def update_from_global(self, global_weights):
        """Update local model from global aggregation"""
        self.model.set_weights(global_weights)

    def predict(self, message):
        """Classify message as spam or ham"""
        features = vectorize_features(extract_features(message))
        spam_probability = float(self.model.predict(features[np.newaxis, :])[0][0])

        return {
            'is_spam': spam_probability > 0.5,
            'spam_score': spam_probability,
            'confidence': abs(spam_probability - 0.5) * 2
        }

4. Federated Coordination Server

import numpy as np

class FederatedSpamCoordinator:
    """Coordinate federated learning across MTAs"""

    def __init__(self):
        self.global_model = SpamClassifier().model
        self.participating_mtas = []

    def register_mta(self, mta_id, public_key):
        """Register MTA for federated learning"""
        self.participating_mtas.append({
            'id': mta_id,
            'public_key': public_key,
            'last_update': None
        })

    def receive_model_update(self, mta_id, encrypted_weights):
        """Receive model update from MTA"""
        # Verify signature
        if not verify_mta_signature(mta_id, encrypted_weights):
            return {'error': 'Invalid signature'}

        # Decrypt weights; with full secure aggregation the coordinator
        # would only ever see the sum, never an individual MTA's update
        weights = decrypt_weights(encrypted_weights)

        # Store update
        self.store_update(mta_id, weights)

        # If enough MTAs have submitted, aggregate
        if self.ready_to_aggregate():
            self.aggregate_and_distribute()

    def aggregate_and_distribute(self):
        """Federated averaging of model updates"""
        # Get all pending updates
        updates = self.get_pending_updates()

        # Federated averaging
        aggregated_weights = self.federated_average(updates)

        # Update global model
        self.global_model.set_weights(aggregated_weights)

        # Distribute to all MTAs
        for mta in self.participating_mtas:
            self.send_global_model(mta['id'], aggregated_weights)

    def federated_average(self, updates):
        """Average model weights across MTAs"""
        # Weight by number of training examples
        total_examples = sum(u['num_examples'] for u in updates)

        averaged_weights = []
        for layer_idx in range(len(updates[0]['weights'])):
            layer_sum = np.zeros_like(updates[0]['weights'][layer_idx])

            for update in updates:
                weight = update['num_examples'] / total_examples
                layer_sum += update['weights'][layer_idx] * weight

            averaged_weights.append(layer_sum)

        return averaged_weights
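
The weighting math above can be sanity-checked on toy data. A self-contained version of the same arithmetic, with two hypothetical MTAs where one trained on 3× as many examples:

```python
import numpy as np

def fedavg(updates):
    """Per-layer weighted mean, weighted by local example counts."""
    total = sum(u['num_examples'] for u in updates)
    n_layers = len(updates[0]['weights'])
    return [
        sum(u['weights'][i] * (u['num_examples'] / total) for u in updates)
        for i in range(n_layers)
    ]

# The larger MTA dominates the average:
updates = [
    {'num_examples': 300, 'weights': [np.array([1.0, 1.0])]},
    {'num_examples': 100, 'weights': [np.array([5.0, 5.0])]},
]
# each element of layer 0: 1.0 * 0.75 + 5.0 * 0.25 = 2.0
```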

Integration with msgs.global

Postfix Integration

# /etc/postfix/main.cf
# Register the ML spam filter as a policy service
# (Postfix does not allow inline comments after parameter values)
smtpd_recipient_restrictions =
    ...
    check_policy_service inet:127.0.0.1:10050
    ...

Policy Service: postfix-smtp/ml-spam-filter.py

#!/usr/bin/env python3
"""ML-based spam filter for Postfix"""

import socket
from spam_classifier import SpamClassifier

classifier = SpamClassifier()
classifier.load_global_model()  # Load from federated coordinator

def handle_policy_request(conn):
    """Postfix policy protocol"""
    request = parse_postfix_request(conn)

    # Build message from request
    message = {
        'from': request['sender'],
        'to': request['recipient'],
        'subject': request.get('subject'),
        'body': request.get('body'),
        'spf_result': request.get('spf'),
        'dkim_result': request.get('dkim')
    }

    # Classify
    result = classifier.predict(message)

    if result['is_spam'] and result['confidence'] > 0.8:
        response = f"action=REJECT Spam detected (score: {result['spam_score']:.2f})\n\n"
    elif result['is_spam']:
        response = f"action=PREPEND X-Spam-Score: {result['spam_score']:.2f}\n\n"
    else:
        response = "action=DUNNO\n\n"

    conn.send(response.encode())
    conn.close()

def main():
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
    sock.bind(('127.0.0.1', 10050))
    sock.listen(5)

    while True:
        # One connection at a time; production would dispatch to a worker pool
        conn, addr = sock.accept()
        handle_policy_request(conn)

if __name__ == '__main__':
    main()
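
The `parse_postfix_request` helper assumed above follows the Postfix policy delegation protocol: the server sends `name=value` lines terminated by an empty line. A sketch (buffer handling simplified):

```python
def parse_policy_payload(data):
    """Parse one policy request: name=value lines ending with a blank line."""
    attrs = {}
    for line in data.decode('utf-8', errors='replace').split('\n'):
        line = line.strip()
        if not line:
            break  # the empty line terminates the request
        name, _, value = line.partition('=')
        attrs[name] = value
    return attrs

def parse_postfix_request(conn):
    """Read a full request (up to the blank line) from the socket."""
    buf = b''
    while b'\n\n' not in buf:
        chunk = conn.recv(4096)
        if not chunk:
            break
        buf += chunk
    return parse_policy_payload(buf)
```

One caveat: the policy protocol only exposes envelope and connection attributes (`sender`, `recipient`, `client_address`, SASL data, ...), not message content, so the `subject`/`body` lookups in the handler above would in practice need a milter or content filter instead.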

Training Pipeline

# Background job: Train on user feedback
async def train_on_feedback():
    """Train classifier on spam/ham labels from users"""
    while True:
        # Get user feedback (spam button clicks, ham corrections)
        feedback = db.query("""
            SELECT message_id, is_spam
            FROM user_spam_feedback
            WHERE processed = FALSE
            LIMIT 1000
        """)

        if feedback:
            # Load messages
            messages = [load_message(f.message_id) for f in feedback]
            labels = [f.is_spam for f in feedback]

            # Train local model
            weights_update = classifier.train_local(messages, labels)

            # Send to federated coordinator
            send_to_coordinator(weights_update, num_examples=len(messages))

            # Mark as processed (expand one placeholder per message id;
            # a bare "IN (?)" cannot be bound to a Python list)
            ids = [f.message_id for f in feedback]
            placeholders = ','.join('?' for _ in ids)
            db.execute(f"""
                UPDATE user_spam_feedback
                SET processed = TRUE
                WHERE message_id IN ({placeholders})
            """, ids)

        # Wait before next training round
        await asyncio.sleep(3600)  # Train hourly

API Endpoints

@app.route('/api/v1/spam/classify', methods=['POST'])
def classify_message_api():
    """Classify message via API"""
    message_data = request.json

    result = classifier.predict(message_data)

    return jsonify(result)

@app.route('/api/v1/spam/feedback', methods=['POST'])
def submit_spam_feedback():
    """User marks message as spam/ham"""
    data = request.json

    db.execute("""
        INSERT INTO user_spam_feedback (
            user_id, message_id, is_spam, feedback_type
        ) VALUES (?, ?, ?, ?)
    """, (request.current_user.id, data['message_id'],
          data['is_spam'], data.get('type', 'user_report')))

    # Trigger retraining
    schedule_training()

    return {'status': 'feedback_recorded'}

@app.route('/api/v1/spam/stats')
def get_spam_stats():
    """Get spam filtering statistics"""
    return jsonify({
        'messages_classified': get_total_classified(),
        'spam_blocked': get_spam_blocked_count(),
        'accuracy': get_model_accuracy(),
        'false_positive_rate': get_false_positive_rate(),
        'model_version': classifier.model_version,
        'last_training': classifier.last_trained_at
    })

Privacy Guarantees

Federated Learning Ensures:

  1. No raw data sharing: Only model gradients shared
  2. Differential privacy: Noise added to gradients
  3. Secure aggregation: Encrypted weight updates
  4. Verifiable updates: Cryptographic signatures
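
Item 2 can be sketched as the Gaussian mechanism: the MTA clips its update's L2 norm and adds calibrated noise before anything leaves the server. A minimal illustration (the `clip_norm` and `noise_multiplier` values are placeholders that a real deployment would set during the privacy audit):

```python
import numpy as np

def privatize_update(weights, clip_norm=1.0, noise_multiplier=0.5, rng=None):
    """Clip the update's global L2 norm, then add Gaussian noise per layer."""
    rng = rng or np.random.default_rng()
    # Global L2 norm across all layers
    flat = np.concatenate([w.ravel() for w in weights])
    norm = np.linalg.norm(flat)
    scale = min(1.0, clip_norm / max(norm, 1e-12))  # enforce L2 <= clip_norm
    noisy = []
    for w in weights:
        noise = rng.normal(0.0, noise_multiplier * clip_norm, size=w.shape)
        noisy.append(w * scale + noise)
    return noisy
```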

Example

msgs.global trains on 10,000 emails (never leaves server)
  ↓
Computes model update (gradient)
  ↓
Adds differential privacy noise
  ↓
Encrypts update
  ↓
Sends to coordinator → [encrypted gradient, not emails]

Performance

Expected Metrics

  • Accuracy: 99%+ (with federated training)
  • False positive rate: <0.1%
  • Latency: <10ms per message
  • Spam block rate: 98%+

Comparison

Method                   Accuracy   False Positives   Data Sharing
Standalone ML            95%        1%                No data
Federated ML             99%        0.1%              Model only
Cloud ML (centralized)   99.5%      0.05%             All data ⚠️

Federated Network

Participating MTAs

msgs.global
├── Contributes: 1M messages/month
├── Benefits: Global spam patterns
└── Privacy: Zero data sharing

Small Provider (example.com)
├── Contributes: 10K messages/month
├── Benefits: Enterprise-grade filtering
└── Privacy: Zero data sharing

Network Effect: 100 MTAs × 100K messages = 10M training examples

Status

🔬 Research & Prototyping Phase

Next Steps

  1. [ ] Implement basic spam classifier (standalone)
  2. [ ] Set up federated learning framework (TensorFlow Federated)
  3. [ ] Build coordination server
  4. [ ] Integrate with Postfix policy service
  5. [ ] User feedback collection UI
  6. [ ] Privacy audit (differential privacy params)
  7. [ ] Recruit partner MTAs for federated network

Related Technologies

  • TensorFlow Federated: Google's federated learning framework
  • PySyft: Privacy-preserving ML
  • Flower: Federated learning framework
  • SpamAssassin: Traditional spam filtering (baseline)