WooCommerce Code Reference

QueryComplexity.php

Source code

<?php declare(strict_types=1);

namespace Automattic\WooCommerce\Vendor\GraphQL\Validator\Rules;

use Automattic\WooCommerce\Vendor\GraphQL\Error\Error;
use Automattic\WooCommerce\Vendor\GraphQL\Error\InvariantViolation;
use Automattic\WooCommerce\Vendor\GraphQL\Executor\Values;
use Automattic\WooCommerce\Vendor\GraphQL\Language\AST\DocumentNode;
use Automattic\WooCommerce\Vendor\GraphQL\Language\AST\FieldNode;
use Automattic\WooCommerce\Vendor\GraphQL\Language\AST\FragmentSpreadNode;
use Automattic\WooCommerce\Vendor\GraphQL\Language\AST\InlineFragmentNode;
use Automattic\WooCommerce\Vendor\GraphQL\Language\AST\NodeKind;
use Automattic\WooCommerce\Vendor\GraphQL\Language\AST\NodeList;
use Automattic\WooCommerce\Vendor\GraphQL\Language\AST\OperationDefinitionNode;
use Automattic\WooCommerce\Vendor\GraphQL\Language\AST\SelectionNode;
use Automattic\WooCommerce\Vendor\GraphQL\Language\AST\SelectionSetNode;
use Automattic\WooCommerce\Vendor\GraphQL\Language\AST\VariableDefinitionNode;
use Automattic\WooCommerce\Vendor\GraphQL\Language\Visitor;
use Automattic\WooCommerce\Vendor\GraphQL\Language\VisitorOperation;
use Automattic\WooCommerce\Vendor\GraphQL\Type\Definition\Directive;
use Automattic\WooCommerce\Vendor\GraphQL\Type\Definition\FieldDefinition;
use Automattic\WooCommerce\Vendor\GraphQL\Type\Introspection;
use Automattic\WooCommerce\Vendor\GraphQL\Validator\QueryValidationContext;

/**
 * @phpstan-import-type ASTAndDefs from QuerySecurityRule
 */
class QueryComplexity extends QuerySecurityRule
{
    protected int $maxQueryComplexity;

    protected int $queryComplexity;

    /** @var array<string, mixed> */
    protected array $rawVariableValues = [];

    /** @var NodeList<VariableDefinitionNode> */
    protected NodeList $variableDefs;

    /** @phpstan-var ASTAndDefs */
    protected \ArrayObject $fieldNodeAndDefs;

    protected QueryValidationContext $context;

    /** @throws \InvalidArgumentException */
    public function __construct(int $maxQueryComplexity)
    {
        $this->setMaxQueryComplexity($maxQueryComplexity);
    }

    public function getVisitor(QueryValidationContext $context): array
    {
        $this->queryComplexity = 0;
        $this->context = $context;
        $this->variableDefs = new NodeList([]);
        $this->fieldNodeAndDefs = new \ArrayObject();

        return $this->invokeIfNeeded(
            $context,
            [
                NodeKind::SELECTION_SET => function (SelectionSetNode $selectionSet) use ($context): void {
                    $this->fieldNodeAndDefs = $this->collectFieldASTsAndDefs(
                        $context,
                        $context->getParentType(),
                        $selectionSet,
                        null,
                        $this->fieldNodeAndDefs
                    );
                },
                NodeKind::VARIABLE_DEFINITION => function ($def): VisitorOperation {
                    $this->variableDefs[] = $def;

                    return Visitor::skipNode();
                },
                NodeKind::DOCUMENT => [
                    'leave' => function (DocumentNode $document) use ($context): void {
                        $errors = $context->getErrors();

                        if ($errors !== []) {
                            return;
                        }

                        if ($this->maxQueryComplexity === self::DISABLED) {
                            return;
                        }

                        foreach ($document->definitions as $definition) {
                            if (! $definition instanceof OperationDefinitionNode) {
                                continue;
                            }

                            $this->queryComplexity = $this->fieldComplexity($definition->selectionSet);

                            if ($this->queryComplexity > $this->maxQueryComplexity) {
                                $context->reportError(
                                    new Error(static::maxQueryComplexityErrorMessage(
                                        $this->maxQueryComplexity,
                                        $this->queryComplexity
                                    ))
                                );

                                return;
                            }
                        }
                    },
                ],
            ]
        );
    }

    /** @throws \Exception */
    protected function fieldComplexity(SelectionSetNode $selectionSet): int
    {
        $complexity = 0;

        foreach ($selectionSet->selections as $selection) {
            $complexity += $this->nodeComplexity($selection);
        }

        return $complexity;
    }

    /** @throws \Exception */
    protected function nodeComplexity(SelectionNode $node): int
    {
        switch (true) {
            case $node instanceof FieldNode:
                // Exclude __schema field and all nested content from complexity calculation
                if ($node->name->value === Introspection::SCHEMA_FIELD_NAME) {
                    return 0;
                }

                if ($this->directiveExcludesField($node)) {
                    return 0;
                }

                $childrenComplexity = isset($node->selectionSet)
                    ? $this->fieldComplexity($node->selectionSet)
                    : 0;

                $fieldDef = $this->fieldDefinition($node);
                if ($fieldDef instanceof FieldDefinition && $fieldDef->complexityFn !== null) {
                    $fieldArguments = $this->buildFieldArguments($node);

                    return ($fieldDef->complexityFn)($childrenComplexity, $fieldArguments);
                }

                return $childrenComplexity + 1;

            case $node instanceof InlineFragmentNode:
                return $this->fieldComplexity($node->selectionSet);

            case $node instanceof FragmentSpreadNode:
                $fragment = $this->getFragment($node);

                if ($fragment !== null) {
                    return $this->fieldComplexity($fragment->selectionSet);
                }
        }

        return 0;
    }

    protected function fieldDefinition(FieldNode $field): ?FieldDefinition
    {
        foreach ($this->fieldNodeAndDefs[$this->getFieldName($field)] ?? [] as [$node, $def]) {
            if ($node === $field) {
                return $def;
            }
        }

        return null;
    }

    /**
     * Will the given field be executed at all, given the directives placed upon it?
     *
     * @throws \Exception
     * @throws \ReflectionException
     * @throws InvariantViolation
     */
    protected function directiveExcludesField(FieldNode $node): bool
    {
        foreach ($node->directives as $directiveNode) {
            if ($directiveNode->name->value === Directive::DEPRECATED_NAME) {
                return false;
            }

            [$errors, $variableValues] = Values::getVariableValues(
                $this->context->getSchema(),
                $this->variableDefs,
                $this->getRawVariableValues()
            );
            if ($errors !== null && $errors !== []) {
                throw new Error(implode("\n\n", array_map(static fn (Error $error): string => $error->getMessage(), $errors)));
            }

            if ($directiveNode->name->value === Directive::INCLUDE_NAME) {
                $includeArguments = Values::getArgumentValues(
                    Directive::includeDirective(),
                    $directiveNode,
                    $variableValues
                );
                assert(is_bool($includeArguments['if']), 'ensured by query validation');

                return ! $includeArguments['if'];
            }

            if ($directiveNode->name->value === Directive::SKIP_NAME) {
                $skipArguments = Values::getArgumentValues(
                    Directive::skipDirective(),
                    $directiveNode,
                    $variableValues
                );
                assert(is_bool($skipArguments['if']), 'ensured by query validation');

                return $skipArguments['if'];
            }
        }

        return false;
    }

    /** @return array<string, mixed> */
    public function getRawVariableValues(): array
    {
        return $this->rawVariableValues;
    }

    /** @param array<string, mixed>|null $rawVariableValues */
    public function setRawVariableValues(?array $rawVariableValues = null): void
    {
        $this->rawVariableValues = $rawVariableValues ?? [];
    }

    /**
     * @throws \Exception
     * @throws Error
     *
     * @return array<string, mixed>
     */
    protected function buildFieldArguments(FieldNode $node): array
    {
        $rawVariableValues = $this->getRawVariableValues();
        $fieldDef = $this->fieldDefinition($node);

        /** @var array<string, mixed> $args */
        $args = [];

        if ($fieldDef instanceof FieldDefinition) {
            [$errors, $variableValues] = Values::getVariableValues(
                $this->context->getSchema(),
                $this->variableDefs,
                $rawVariableValues
            );

            if (is_array($errors) && $errors !== []) {
                throw new Error(implode("\n\n", array_map(static fn ($error) => $error->getMessage(), $errors)));
            }

            $args = Values::getArgumentValues($fieldDef, $node, $variableValues);
        }

        return $args;
    }

    public function getMaxQueryComplexity(): int
    {
        return $this->maxQueryComplexity;
    }

    /**
     * Complexity of the first operation exceeding the defined limit, or, in case no operation
     * exceeds the limit, complexity of the last defined operation.
     */
    public function getQueryComplexity(): int
    {
        return $this->queryComplexity;
    }

    /**
     * Set max query complexity. If equal to 0 no check is done. Must be greater or equal to 0.
     *
     * @throws \InvalidArgumentException
     */
    public function setMaxQueryComplexity(int $maxQueryComplexity): void
    {
        $this->checkIfGreaterOrEqualToZero('maxQueryComplexity', $maxQueryComplexity);

        $this->maxQueryComplexity = $maxQueryComplexity;
    }

    public static function maxQueryComplexityErrorMessage(int $max, int $count): string
    {
        return "Max query complexity should be {$max} but got {$count}.";
    }

    protected function isEnabled(): bool
    {
        return $this->maxQueryComplexity !== self::DISABLED;
    }
}