<?php
namespace GraphQL\Validator\Rules;

use GraphQL\Error\Error;
use GraphQL\Executor\Values;
use GraphQL\Language\AST\FieldNode;
use GraphQL\Language\AST\FragmentSpreadNode;
use GraphQL\Language\AST\InlineFragmentNode;
use GraphQL\Language\AST\Node;
use GraphQL\Language\AST\NodeKind;
use GraphQL\Language\AST\OperationDefinitionNode;
use GraphQL\Language\AST\SelectionSetNode;
use GraphQL\Language\Visitor;
use GraphQL\Type\Definition\Directive;
use GraphQL\Type\Definition\FieldDefinition;
use GraphQL\Validator\ValidationContext;

/**
 * Validation rule that reports an error when an operation's computed complexity exceeds a
 * configured maximum.
 *
 * The complexity of each field is obtained from its definition's complexity function (or
 * FieldDefinition::DEFAULT_COMPLEXITY_FN), called with the summed complexity of its children and
 * its coerced arguments. Fields and fragments excluded by @skip / @include are not counted.
 */
class QueryComplexity extends AbstractQuerySecurity
{
    /** @var int Maximum allowed complexity; 0 disables the check. */
    private $maxQueryComplexity;

    /** @var array Raw (not yet coerced) variable values used to evaluate arguments and directives. */
    private $rawVariableValues = [];

    /** @var \ArrayObject Variable definitions of the operation currently being visited. */
    private $variableDefs;

    /** @var \ArrayObject Field nodes with their definitions, as gathered by collectFieldASTsAndDefs(). */
    private $fieldNodeAndDefs;

    /**
     * @var ValidationContext
     */
    private $context;

    /**
     * @param int $maxQueryComplexity maximum allowed complexity, 0 to disable the check
     */
    public function __construct($maxQueryComplexity)
    {
        $this->setMaxQueryComplexity($maxQueryComplexity);
    }

    /**
     * Builds the error message reported when an operation is too complex.
     *
     * @param int $max   configured maximum
     * @param int $count computed complexity
     *
     * @return string
     */
    public static function maxQueryComplexityErrorMessage($max, $count)
    {
        return sprintf('Max query complexity should be %d but got %d.', $max, $count);
    }

    /**
     * Set max query complexity. If equal to 0 no check is done. Must be greater or equal to 0.
     *
     * The value is validated by the inherited checkIfGreaterOrEqualToZero() and then cast to int.
     *
     * @param int $maxQueryComplexity
     */
    public function setMaxQueryComplexity($maxQueryComplexity)
    {
        $this->checkIfGreaterOrEqualToZero('maxQueryComplexity', $maxQueryComplexity);

        $this->maxQueryComplexity = (int) $maxQueryComplexity;
    }

    /**
     * @return int
     */
    public function getMaxQueryComplexity()
    {
        return $this->maxQueryComplexity;
    }

    /**
     * Sets the raw variable values used to coerce field and directive arguments.
     *
     * @param array|null $rawVariableValues null is stored as an empty array
     */
    public function setRawVariableValues(array $rawVariableValues = null)
    {
        $this->rawVariableValues = $rawVariableValues ?: [];
    }

    /**
     * @return array
     */
    public function getRawVariableValues()
    {
        return $this->rawVariableValues;
    }

    /**
     * Returns the visitor that collects fields and variable definitions and, on leaving each
     * operation, reports an error if that operation's complexity exceeds the maximum.
     *
     * @param ValidationContext $context
     *
     * @return array
     */
    public function getVisitor(ValidationContext $context)
    {
        $this->context = $context;
        $this->variableDefs = new \ArrayObject();
        $this->fieldNodeAndDefs = new \ArrayObject();

        return $this->invokeIfNeeded(
            $context,
            [
                NodeKind::SELECTION_SET => function (SelectionSetNode $selectionSet) use ($context) {
                    $this->fieldNodeAndDefs = $this->collectFieldASTsAndDefs(
                        $context,
                        $context->getParentType(),
                        $selectionSet,
                        null,
                        $this->fieldNodeAndDefs
                    );
                },
                NodeKind::VARIABLE_DEFINITION => function ($def) {
                    $this->variableDefs[] = $def;

                    return Visitor::skipNode();
                },
                NodeKind::OPERATION_DEFINITION => [
                    'enter' => function () {
                        // Variables are scoped to a single operation: start every operation with a
                        // fresh list so earlier operations' definitions are not coerced again.
                        $this->variableDefs = new \ArrayObject();
                    },
                    'leave' => function (OperationDefinitionNode $operationDefinition) use ($context) {
                        // Only measure documents that are otherwise valid.
                        $errors = $context->getErrors();
                        if (!empty($errors)) {
                            return;
                        }

                        // Each operation is measured on its own; only one of them is ever executed.
                        $complexity = $this->fieldComplexity($operationDefinition);

                        if ($complexity > $this->getMaxQueryComplexity()) {
                            $context->reportError(
                                new Error(static::maxQueryComplexityErrorMessage($this->getMaxQueryComplexity(), $complexity))
                            );
                        }
                    },
                ],
            ]
        );
    }

    /**
     * Adds the complexity of every selection in $node's selection set to $complexity.
     *
     * @param Node $node       node that may carry a selectionSet (operation, field, fragment)
     * @param int  $complexity starting value
     *
     * @return int $complexity unchanged when $node has no selection set
     */
    private function fieldComplexity($node, $complexity = 0)
    {
        if (isset($node->selectionSet) && $node->selectionSet instanceof SelectionSetNode) {
            foreach ($node->selectionSet->selections as $childNode) {
                $complexity = $this->nodeComplexity($childNode, $complexity);
            }
        }

        return $complexity;
    }

    /**
     * Adds the complexity of a single selection (field, inline fragment or fragment spread).
     *
     * @param Node $node
     * @param int  $complexity running total
     *
     * @return int
     */
    private function nodeComplexity(Node $node, $complexity = 0)
    {
        switch ($node->kind) {
            case NodeKind::FIELD:
                /* @var FieldNode $node */
                // defaults used when no field definition was collected for this node
                $args = [];
                $complexityFn = FieldDefinition::DEFAULT_COMPLEXITY_FN;

                $astFieldInfo = $this->astFieldInfo($node);
                $fieldDef = $astFieldInfo[1];

                if ($fieldDef instanceof FieldDefinition) {
                    // An excluded field contributes nothing; bail out before walking its
                    // (possibly large) sub-selection.
                    if ($this->directiveExcludesNode($node)) {
                        break;
                    }

                    $args = $this->buildFieldArguments($node);

                    if (method_exists($fieldDef, 'getComplexityFn')) {
                        $complexityFn = $fieldDef->getComplexityFn();
                    }
                }

                // fieldComplexity() yields 0 for leaf fields (no selection set).
                $childrenComplexity = $this->fieldComplexity($node);
                $complexity += call_user_func_array($complexityFn, [$childrenComplexity, $args]);
                break;

            case NodeKind::INLINE_FRAGMENT:
                /* @var InlineFragmentNode $node */
                if (!$this->directiveExcludesNode($node)) {
                    $complexity = $this->fieldComplexity($node, $complexity);
                }
                break;

            case NodeKind::FRAGMENT_SPREAD:
                /* @var FragmentSpreadNode $node */
                if (!$this->directiveExcludesNode($node)) {
                    $fragment = $this->getFragment($node);
                    if (null !== $fragment) {
                        $complexity = $this->fieldComplexity($fragment, $complexity);
                    }
                }
                break;
        }

        return $complexity;
    }

    /**
     * Looks up the [node, definition] pair collected for exactly this field node.
     *
     * @param FieldNode $field
     *
     * @return array [FieldNode|null, FieldDefinition|null]; [null, null] when not collected
     */
    private function astFieldInfo(FieldNode $field)
    {
        $fieldName = $this->getFieldName($field);

        if (isset($this->fieldNodeAndDefs[$fieldName])) {
            foreach ($this->fieldNodeAndDefs[$fieldName] as $astAndDef) {
                // Identity, not structural equality: `==` deep-compares the whole subtree and can
                // match a different but identical-looking node (e.g. ASTs without location info).
                if ($astAndDef[0] === $field) {
                    return $astAndDef;
                }
            }
        }

        return [null, null];
    }

    /**
     * Coerces the field's arguments using the current operation's variables.
     *
     * @param FieldNode $node
     *
     * @return array coerced argument values, or [] when the field definition is unknown
     *
     * @throws Error when the raw variable values cannot be coerced
     */
    private function buildFieldArguments(FieldNode $node)
    {
        $astFieldInfo = $this->astFieldInfo($node);
        $fieldDef = $astFieldInfo[1];

        if (!($fieldDef instanceof FieldDefinition)) {
            return [];
        }

        return Values::getArgumentValues($fieldDef, $node, $this->getCoercedVariableValues());
    }

    /**
     * Coerces the raw variable values against the current operation's variable definitions.
     *
     * @return array the 'coerced' entry of Values::getVariableValues()
     *
     * @throws Error when coercion fails; all coercion messages are joined by blank lines
     */
    private function getCoercedVariableValues()
    {
        $variableValuesResult = Values::getVariableValues(
            $this->context->getSchema(),
            $this->variableDefs,
            $this->getRawVariableValues()
        );

        if ($variableValuesResult['errors']) {
            throw new Error(implode("\n\n", array_map(
                function ($error) {
                    return $error->getMessage();
                },
                $variableValuesResult['errors']
            )));
        }

        return $variableValuesResult['coerced'];
    }

    /**
     * Tells whether @skip / @include directives remove the node (field, inline fragment or
     * fragment spread) from the result.
     *
     * Every directive is examined. The node is excluded as soon as one @skip has `if: true` or one
     * @include has `if: false`. Any other directive does not affect the outcome.
     *
     * @param Node $node node exposing a `directives` list
     *
     * @return bool
     *
     * @throws Error when the raw variable values cannot be coerced
     */
    private function directiveExcludesNode(Node $node)
    {
        $variableValues = null;
        $variablesCoerced = false;

        foreach ($node->directives ?: [] as $directiveNode) {
            $directiveName = $directiveNode->name->value;

            if ($directiveName === 'skip') {
                $directive = Directive::skipDirective();
                $excludeWhen = true;
            } elseif ($directiveName === 'include') {
                $directive = Directive::includeDirective();
                $excludeWhen = false;
            } else {
                // Not a conditional-inclusion directive (e.g. a custom one): irrelevant here.
                continue;
            }

            // Coerce variables lazily and only once per node.
            if (!$variablesCoerced) {
                $variableValues = $this->getCoercedVariableValues();
                $variablesCoerced = true;
            }

            $directiveArgs = Values::getArgumentValues($directive, $directiveNode, $variableValues);
            if ((bool) $directiveArgs['if'] === $excludeWhen) {
                return true;
            }
        }

        return false;
    }

    /**
     * @return bool false when the maximum equals static::DISABLED
     */
    protected function isEnabled()
    {
        return $this->getMaxQueryComplexity() !== static::DISABLED;
    }
}