Python高质量函数编写指南

2024-04-18 15:47:40 浏览数 (2)

The Ultimate Guide to Writing Functions 1.视频 https://www.youtube.com/watch?v=yatgY4NpZXE 2.代码 https://github.com/ArjanCodes/2022-funcguide

Python高质量函数编写指南

1. 一次做好一件事

代码语言:python
from dataclasses import dataclass
from datetime import datetime


@dataclass
class Customer:
    """A customer record together with the customer's credit-card details."""

    name: str
    phone: str
    cc_number: str  # card number; validate_card converts each character with int()
    cc_exp_month: int  # expiry month, used as the month argument of datetime()
    cc_exp_year: int  # expiry year, used as the year argument of datetime()
    cc_valid: bool = False  # defaults to False; overwritten by validate_card

# validate_card函数做了太多事情
def validate_card(customer: Customer) -> bool:
    """Validate a customer's card, store the result on the customer, return it.

    The card counts as valid when its number passes the Luhn checksum and the
    first day of its expiry month is later than the current local time.

    Args:
        customer: Record whose ``cc_number``, ``cc_exp_month`` and
            ``cc_exp_year`` are read and whose ``cc_valid`` is overwritten.

    Returns:
        The value that was stored in ``customer.cc_valid``.

    Raises:
        ValueError: If ``cc_number`` contains a character ``int()`` rejects,
            or if the checksum passes and the expiry year/month do not form
            a valid date.
    """

    def digits_of(number: str) -> list[int]:
        return [int(d) for d in number]

    digits = digits_of(customer.cc_number)
    odd_digits = digits[-1::-2]  # rightmost digit and every second one leftwards
    even_digits = digits[-2::-2]  # the digits in between, which get doubled
    checksum = sum(odd_digits)
    for digit in even_digits:
        # Accumulate with +=; a plain assignment here would discard every
        # contribution except that of the last doubled digit.
        checksum += sum(digits_of(str(digit * 2)))

    customer.cc_valid = (
        checksum % 10 == 0
        and datetime(customer.cc_exp_year, customer.cc_exp_month, 1) > datetime.now()
    )
    return customer.cc_valid


def main() -> None:
    """Validate a sample customer's card and print the verdict and the record."""
    alice_details = {
        "name": "Alice",
        "phone": "2341",
        "cc_number": "1249190007575069",
        "cc_exp_month": 1,
        "cc_exp_year": 2024,
    }
    alice = Customer(**alice_details)
    is_valid = validate_card(alice)
    print(f"Is Alice's card valid? {is_valid}")
    print(alice)


if __name__ == "__main__":
    main()

我们发现 validate_card 函数做了两件事:验证卡号的校验和(Luhn 算法)是否有效,以及验证卡片是否尚未过期。我们把校验和的验证拆分成单独的函数 luhn_checksum,并在 validate_card 中调用。修改后:

代码语言:python
from dataclasses import dataclass
from datetime import datetime

# 验证和 函数
def luhn_checksum(card_number: str) -> bool:
    """Return True when *card_number* passes the Luhn checksum.

    Starting from the rightmost digit, every second digit is added as is;
    the digits in between are doubled and the decimal digits of each product
    are added. The number passes when the total is a multiple of 10.

    Args:
        card_number: The card number as a string of decimal digits.

    Returns:
        Whether the Luhn total is divisible by 10.

    Raises:
        ValueError: If *card_number* contains a character ``int()`` rejects.
    """

    def digits_of(number: str) -> list[int]:
        return [int(d) for d in number]

    digits = digits_of(card_number)
    odd_digits = digits[-1::-2]
    even_digits = digits[-2::-2]
    checksum = sum(odd_digits)
    for digit in even_digits:
        # Accumulate with +=; a plain assignment here would discard every
        # contribution except that of the last doubled digit.
        checksum += sum(digits_of(str(digit * 2)))
    return checksum % 10 == 0


@dataclass
class Customer:
    """Customer record, abbreviated in this excerpt."""

    name: str
    phone: str
    # The remaining fields are elided here; validate_card below reads
    # cc_number, cc_exp_month and cc_exp_year and assigns cc_valid.
    ...


def validate_card(customer: Customer) -> bool:
    """Store on the customer, and return, whether the customer's card is valid.

    The card is valid when ``luhn_checksum`` accepts ``cc_number`` and the
    first day of the expiry month is later than the current local time. The
    expiry date is only built when the checksum passes.
    """
    if not luhn_checksum(customer.cc_number):
        customer.cc_valid = False
    else:
        expiry_month_start = datetime(customer.cc_exp_year, customer.cc_exp_month, 1)
        customer.cc_valid = expiry_month_start > datetime.now()
    return customer.cc_valid

2. 分离命令和查询(command and query)

validate_card 中同时进行了查询和赋值两个操作,这样不好。我们将查询和赋值拆分成两个步骤:validate_card 只返回卡是否有效,而赋值操作 alice.cc_valid = validate_card(alice) 移动到了主函数中。

代码语言:python
from dataclasses import dataclass
from datetime import datetime


def luhn_checksum(card_number: str) -> bool:
    """Return True when *card_number* passes the Luhn checksum.

    Starting from the rightmost digit, every second digit is added as is;
    the digits in between are doubled and the decimal digits of each product
    are added. The number passes when the total is a multiple of 10.

    Args:
        card_number: The card number as a string of decimal digits.

    Returns:
        Whether the Luhn total is divisible by 10.

    Raises:
        ValueError: If *card_number* contains a character ``int()`` rejects.
    """

    def digits_of(number: str) -> list[int]:
        return [int(d) for d in number]

    digits = digits_of(card_number)
    odd_digits = digits[-1::-2]
    even_digits = digits[-2::-2]
    checksum = sum(odd_digits)
    for digit in even_digits:
        # Accumulate with +=; a plain assignment here would discard every
        # contribution except that of the last doubled digit.
        checksum += sum(digits_of(str(digit * 2)))
    return checksum % 10 == 0


@dataclass
class Customer:
    """A customer record together with the customer's credit-card details."""

    name: str
    phone: str
    cc_number: str  # card number; passed to luhn_checksum by validate_card
    cc_exp_month: int  # expiry month, used as the month argument of datetime()
    cc_exp_year: int  # expiry year, used as the year argument of datetime()
    cc_valid: bool = False  # defaults to False; main() assigns the validation result

# 查询
def validate_card(customer: Customer) -> bool:
    """Return whether the customer's card is valid, without modifying the customer.

    The card is valid when ``luhn_checksum`` accepts ``cc_number`` and the
    first day of the expiry month is later than the current local time. The
    expiry date is only built when the checksum passes.
    """
    if not luhn_checksum(customer.cc_number):
        return False
    expiry_month_start = datetime(customer.cc_exp_year, customer.cc_exp_month, 1)
    return expiry_month_start > datetime.now()


def main() -> None:
    """Build a sample customer, record the validation result on it and print it."""
    alice = Customer(
        name="Alice", phone="2341", cc_number="1249190007575069",
        cc_exp_month=1, cc_exp_year=2024,
    )
    # Command: the caller, not validate_card, stores the result of the query.
    verdict = validate_card(alice)
    alice.cc_valid = verdict
    print(f"Is Alice's card valid? {alice.cc_valid}")
    print(alice)


if __name__ == "__main__":
    main()

3. 只请求你需要的

函数 validate_card 实际上只需要 3 个参数(而不需要整个 Customer 对象),因此只请求这 3 个参数:`def validate_card(*, number: str, exp_month: int, exp_year: int) -> bool:`

代码语言:python
from dataclasses import dataclass
from datetime import datetime


def luhn_checksum(card_number: str) -> bool:
    """Return True when *card_number* passes the Luhn checksum.

    Starting from the rightmost digit, every second digit is added as is;
    the digits in between are doubled and the decimal digits of each product
    are added. The number passes when the total is a multiple of 10.

    Args:
        card_number: The card number as a string of decimal digits.

    Returns:
        Whether the Luhn total is divisible by 10.

    Raises:
        ValueError: If *card_number* contains a character ``int()`` rejects.
    """

    def digits_of(number: str) -> list[int]:
        return [int(d) for d in number]

    digits = digits_of(card_number)
    odd_digits = digits[-1::-2]
    even_digits = digits[-2::-2]
    checksum = sum(odd_digits)
    for digit in even_digits:
        # Accumulate with +=; a plain assignment here would discard every
        # contribution except that of the last doubled digit.
        checksum += sum(digits_of(str(digit * 2)))
    return checksum % 10 == 0


@dataclass
class Customer:
    """A customer record together with the customer's credit-card details."""

    name: str
    phone: str
    cc_number: str  # card number; main() passes it to validate_card as `number`
    cc_exp_month: int  # expiry month; passed to validate_card as `exp_month`
    cc_exp_year: int  # expiry year; passed to validate_card as `exp_year`
    cc_valid: bool = False  # defaults to False; main() assigns the validation result

# 只请求你需要的参数
def validate_card(*, number: str, exp_month: int, exp_year: int) -> bool:
    """Return whether a card with the given number and expiry is valid.

    All arguments are keyword-only. The card is valid when ``luhn_checksum``
    accepts *number* and the first day of the expiry month is later than the
    current local time. The expiry date is only built when the checksum
    passes.
    """
    if not luhn_checksum(number):
        return False
    expiry_month_start = datetime(exp_year, exp_month, 1)
    return expiry_month_start > datetime.now()


def main() -> None:
    """Validate a sample customer's card from its three card fields and print it."""
    alice = Customer(
        name="Alice", phone="2341", cc_number="1249190007575069",
        cc_exp_month=1, cc_exp_year=2024,
    )
    # Hand over only the fields validate_card asks for, not the whole customer.
    card_fields = {
        "number": alice.cc_number,
        "exp_month": alice.cc_exp_month,
        "exp_year": alice.cc_exp_year,
    }
    alice.cc_valid = validate_card(**card_fields)
    print(f"Is Alice's card valid? {alice.cc_valid}")
    print(alice)


if __name__ == "__main__":
    main()

4. 保持最小参数量

参数量很多时,调用时传参会比较麻烦。另一方面,函数需要很多参数,则暗示该函数可能做了很多事情。

下面我们抽象出 Card 类,减少了 Customer 和 validate_card 的参数量。

代码语言:python
from dataclasses import dataclass
from datetime import datetime
from typing import Protocol


def luhn_checksum(card_number: str) -> bool:
    """Return True when *card_number* passes the Luhn checksum.

    Starting from the rightmost digit, every second digit is added as is;
    the digits in between are doubled and the decimal digits of each product
    are added. The number passes when the total is a multiple of 10.

    Args:
        card_number: The card number as a string of decimal digits.

    Returns:
        Whether the Luhn total is divisible by 10.

    Raises:
        ValueError: If *card_number* contains a character ``int()`` rejects.
    """

    def digits_of(number: str) -> list[int]:
        return [int(d) for d in number]

    digits = digits_of(card_number)
    odd_digits = digits[-1::-2]
    even_digits = digits[-2::-2]
    checksum = sum(odd_digits)
    for digit in even_digits:
        # Accumulate with +=; a plain assignment here would discard every
        # contribution except that of the last doubled digit.
        checksum += sum(digits_of(str(digit * 2)))
    return checksum % 10 == 0


@dataclass
class Card:
    """Credit-card details, kept separate from the customer that owns the card."""

    number: str  # card number; passed to luhn_checksum by validate_card
    exp_month: int  # expiry month, used as the month argument of datetime()
    exp_year: int  # expiry year, used as the year argument of datetime()
    valid: bool = False  # defaults to False; main() assigns the validation result


@dataclass
class Customer:
    """A customer record that refers to the customer's card as a Card object."""

    name: str
    phone: str
    card: Card
    # NOTE(review): nothing in this snippet reads or writes card_valid; main()
    # stores the validation result on card.valid instead - confirm it is needed.
    card_valid: bool = False


class CardInfo(Protocol):
    """Structural type naming the three read-only attributes validate_card reads."""

    @property
    def number(self) -> str:
        """The card number."""
        ...

    @property
    def exp_month(self) -> int:
        """The expiry month."""
        ...

    @property
    def exp_year(self) -> int:
        """The expiry year."""
        ...


def validate_card(card: CardInfo) -> bool:
    """Return whether *card* is valid, without modifying it.

    The card is valid when ``luhn_checksum`` accepts ``card.number`` and the
    first day of the expiry month is later than the current local time. The
    expiry date is only built when the checksum passes.
    """
    if not luhn_checksum(card.number):
        return False
    expiry_month_start = datetime(card.exp_year, card.exp_month, 1)
    return expiry_month_start > datetime.now()


def main() -> None:
    """Create a customer with a card, validate the card and print the result."""
    # The customer receives one Card object instead of three separate card fields.
    alice = Customer(
        name="Alice",
        phone="2341",
        card=Card(number="1249190007575069", exp_month=1, exp_year=2024),
    )
    card = alice.card
    # validate_card likewise takes the card as a single argument.
    card.valid = validate_card(card)
    print(f"Is Alice's card valid? {card.valid}")
    print(alice)


if __name__ == "__main__":
    main()

5. 不要在同一个地方创建并使用对象

不要在函数内创建对象并使用,更好的方式是在外面创建对象并作为参数传递给函数。

代码语言:python
import logging


class StripePaymentHandler:
    """Payment handler for Stripe; in this example it only logs the charge."""

    def handle_payment(self, amount: int) -> None:
        """Log an INFO message announcing a charge of *amount*.

        The message shows ``amount / 100`` with two decimals.
        """
        message = f"Charging ${amount/100:.2f} using Stripe"
        logging.info(message)


# Price of each menu item. order_food displays the total as total/100 with
# two decimals after a "$", so the values are hundredths of a dollar.
PRICES = {
    "burger": 10_00,
    "fries": 5_00,
    "drink": 2_00,
    "salad": 15_00,
}

# !!
def order_food(items: list[str]) -> None:
    """Total the ordered items, charge the total through Stripe and log progress.

    Args:
        items: Names of the ordered items; each must be a key of ``PRICES``.

    Raises:
        KeyError: If an item is not a key of ``PRICES``.
    """
    total = 0
    for item in items:
        total += PRICES[item]
    logging.info(f"Order total is ${total/100:.2f}.")
    # The handler is created and used in the same place, which ties this
    # function to StripePaymentHandler.
    StripePaymentHandler().handle_payment(total)
    logging.info("Order completed.")


def main() -> None:
    """Enable INFO logging and place a sample order."""
    logging.basicConfig(level=logging.INFO)
    ordered_items = ["burger", "fries", "drink"]
    order_food(ordered_items)


if __name__ == "__main__":
    main()

修改后:

代码语言:python
import logging
from typing import Protocol


class StripePaymentHandler:
    """Payment handler for Stripe; in this example it only logs the charge."""

    def handle_payment(self, amount: int) -> None:
        """Log an INFO message announcing a charge of *amount*.

        The message shows ``amount / 100`` with two decimals.
        """
        message = f"Charging ${amount/100:.2f} using Stripe"
        logging.info(message)


# Price of each menu item. order_food displays the total as total/100 with
# two decimals after a "$", so the values are hundredths of a dollar.
PRICES = {
    "burger": 10_00,
    "fries": 5_00,
    "drink": 2_00,
    "salad": 15_00,
}


class PaymentHandler(Protocol):
    """Structural type for any object offering ``handle_payment(amount)``."""

    def handle_payment(self, amount: int) -> None:
        """Process a payment of *amount*."""
        ...

# !!  现在通过参数传入对象
def order_food(items: list[str], payment_handler: PaymentHandler) -> None:
    """Total the ordered items, charge the total via *payment_handler*, log progress.

    Args:
        items: Names of the ordered items; each must be a key of ``PRICES``.
        payment_handler: Object whose ``handle_payment`` receives the total.

    Raises:
        KeyError: If an item is not a key of ``PRICES``.
    """
    item_prices = [PRICES[item] for item in items]
    total = sum(item_prices)
    logging.info(f"Order total is ${total/100:.2f}.")
    # The handler arrives as an argument; this function never constructs one.
    payment_handler.handle_payment(total)
    logging.info("Order completed.")


def main() -> None:
    """Enable INFO logging and place a sample order paid through Stripe."""
    logging.basicConfig(level=logging.INFO)
    ordered_items = ["burger", "salad", "drink"]
    stripe_handler = StripePaymentHandler()
    order_food(ordered_items, stripe_handler)


if __name__ == "__main__":
    main()

6. 不要用flag参数

flag 参数意味着函数处理两种情况,函数会变得复杂。建议将两种情况拆分成单独的函数。

代码语言:python
from dataclasses import dataclass
from enum import StrEnum, auto

# Number of vacation days deducted by a single payout; a payout is refused
# when fewer days than this remain.
FIXED_VACATION_DAYS_PAYOUT = 5


class Role(StrEnum):
    """Employee roles; on a StrEnum, ``auto()`` yields the lower-cased member name."""

    PRESIDENT = auto()
    VICEPRESIDENT = auto()
    MANAGER = auto()
    LEAD = auto()
    ENGINEER = auto()
    INTERN = auto()


@dataclass
class Employee:
    """An employee with a role and a balance of vacation days (25 by default)."""

    name: str
    role: Role
    vacation_days: int = 25

    def take_a_holiday(self, payout: bool, nr_days: int = 1) -> None:
        """Either pay out a fixed block of vacation days or take *nr_days* off.

        Args:
            payout: When true, deduct ``FIXED_VACATION_DAYS_PAYOUT`` days as
                a payout and ignore *nr_days*; otherwise deduct *nr_days*.
            nr_days: Number of days to take off when *payout* is false.

        Raises:
            ValueError: If fewer vacation days remain than the chosen
                operation deducts.
        """
        if payout:
            if self.vacation_days < FIXED_VACATION_DAYS_PAYOUT:
                # Two adjacent literals form one message; a single-quoted
                # f-string cannot span physical lines on its own.
                raise ValueError(
                    "You don't have enough holidays left over for a payout. "
                    f"Remaining holidays: {self.vacation_days}."
                )
            self.vacation_days -= FIXED_VACATION_DAYS_PAYOUT
            print(f"Paying out a holiday. Holidays left: {self.vacation_days}")
        else:
            if self.vacation_days < nr_days:
                raise ValueError(
                    "You don't have any holidays left. Now back to work, you!"
                )
            self.vacation_days -= nr_days
            print("Have fun on your holiday. Don't forget to check your emails!")


def main() -> None:
    """Create a sample employee and request a holiday payout."""
    john = Employee(name="John Doe", role=Role.ENGINEER)
    wants_payout = True
    john.take_a_holiday(wants_payout)


if __name__ == "__main__":
    main()

修改后:

代码语言:python
from dataclasses import dataclass
from enum import StrEnum, auto

# Number of vacation days deducted by a single payout; a payout is refused
# when fewer days than this remain.
FIXED_VACATION_DAYS_PAYOUT = 5


class Role(StrEnum):
    """Employee roles; on a StrEnum, ``auto()`` yields the lower-cased member name."""

    PRESIDENT = auto()
    VICEPRESIDENT = auto()
    MANAGER = auto()
    LEAD = auto()
    ENGINEER = auto()
    INTERN = auto()


@dataclass
class Employee:
    """An employee with a role and a balance of vacation days (25 by default)."""

    name: str
    role: Role
    vacation_days: int = 25

    def payout_holiday(self) -> None:
        """Deduct ``FIXED_VACATION_DAYS_PAYOUT`` days as a payout and print the balance.

        Raises:
            ValueError: If fewer than ``FIXED_VACATION_DAYS_PAYOUT`` days remain.
        """
        if self.vacation_days < FIXED_VACATION_DAYS_PAYOUT:
            # Two adjacent literals form one message; a single-quoted
            # f-string cannot span physical lines on its own.
            raise ValueError(
                "You don't have enough holidays left over for a payout. "
                f"Remaining holidays: {self.vacation_days}."
            )
        self.vacation_days -= FIXED_VACATION_DAYS_PAYOUT
        print(f"Paying out a holiday. Holidays left: {self.vacation_days}")

    def take_holiday(self, nr_days: int = 1) -> None:
        """Deduct *nr_days* vacation days and print a farewell message.

        Raises:
            ValueError: If fewer than *nr_days* vacation days remain.
        """
        if self.vacation_days < nr_days:
            raise ValueError("You don't have any holidays left. Now back to work, you!")
        self.vacation_days -= nr_days
        print("Have fun on your holiday. Don't forget to check your emails!")


def main() -> None:
    """Create a sample employee and pay out part of the holiday balance."""
    john = Employee(name="John Doe", role=Role.ENGINEER)
    john.payout_holiday()


if __name__ == "__main__":
    main()

7. 函数也是对象

函数也是对象,因此可以作为参数传递,作为函数返回值。

代码语言:python
import logging
from functools import partial
from typing import Callable


def handle_payment_stripe(amount: int) -> None:
    """Log an INFO message announcing a Stripe charge of *amount*.

    The message shows ``amount / 100`` with two decimals.
    """
    message = f"Charging ${amount/100:.2f} using Stripe"
    logging.info(message)


# Price of each menu item. order_food displays the total as total/100 with
# two decimals after a "$", so the values are hundredths of a dollar.
PRICES = {
    "burger": 10_00,
    "fries": 5_00,
    "drink": 2_00,
    "salad": 15_00,
}


# Type of a payment callback: takes the amount as an int and returns None.
HandlePaymentFn = Callable[[int], None]

# 函数作为参数
def order_food(items: list[str], payment_handler: HandlePaymentFn) -> None:
    """Total the ordered items, pass the total to *payment_handler*, log progress.

    Args:
        items: Names of the ordered items; each must be a key of ``PRICES``.
        payment_handler: Function called once with the order total.

    Raises:
        KeyError: If an item is not a key of ``PRICES``.
    """
    total = sum(map(PRICES.__getitem__, items))
    logging.info(f"Order total is ${total/100:.2f}.")
    # The payment step is a plain function received as an argument.
    payment_handler(total)
    logging.info("Order completed.")


# order_food with payment_handler pre-bound to handle_payment_stripe, so
# callers pass only the list of items.
order_food_stripe = partial(order_food, payment_handler=handle_payment_stripe)


def main() -> None:
    """Enable INFO logging and place a sample order via the Stripe-bound partial."""
    logging.basicConfig(level=logging.INFO)
    ordered_items = ["burger", "salad", "drink"]
    # Without the partial the same order would be written as:
    # order_food(ordered_items, handle_payment_stripe)
    order_food_stripe(ordered_items)


if __name__ == "__main__":
    main()

0 人点赞